# coding=utf-8
# Copyright 2020 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Continuous Q-Learning via random sampling."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl import logging
import enum
import gin
import tensorflow.compat.v1 as tf
class DQNTarget(enum.Enum):
"""Enum constants for DQN target network variants.
Attributes:
notarget: No target network used. Next-step action-value is computed
using the online Q network.
normal: Target network used to select action and evaluate next-step
action-value.
doubleq: Double-Q Learning as proposed by https://arxiv.org/abs/1509.06461.
Action is selected by online Q network but evaluated using target network.
"""
notarget = 'notarget'
normal = 'normal'
doubleq = 'doubleq'
gin.constant('DQNTarget.notarget', DQNTarget.notarget)
gin.constant('DQNTarget.normal', DQNTarget.normal)
gin.constant('DQNTarget.doubleq', DQNTarget.doubleq)
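# With the constants registered above, a gin config file can select a variant
# by macro reference, for example (an illustrative binding, not one shipped
# with this module):
#   random_continuous_q_graph.target_network_type = %DQNTarget.doubleq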
@gin.configurable
def discrete_q_graph(q_func,
transition,
target_network_type=DQNTarget.normal,
gamma=1.0,
loss_fn=tf.losses.huber_loss,
extra_callback=None):
"""Construct loss/summary graph for discrete Q-Learning (DQN).
This Q-function loss implementation is derived from OpenAI baselines.
This function supports dynamic batch sizes.
Args:
q_func: Python function that takes in state, scope as input
and returns a tensor Q(a_0...a_N) for each action a_0...a_N, and
intermediate endpoints dictionary.
transition: SARSTransition namedtuple.
target_network_type: Option to use Q Learning without target network, Q
Learning with a target network (default), or Double-Q Learning with a
target network.
gamma: Discount factor.
loss_fn: Function that computes the td_loss tensor. Takes as arguments
(target value tensor, predicted value tensor).
extra_callback: Optional function that takes in (transition, end_points_t,
end_points_tp1) and adds additional TF graph elements.
Returns:
A tuple (loss, summaries) where loss is a scalar loss tensor to minimize,
summaries are TensorFlow summaries.
"""
state = transition.state
action = transition.action
state_p1 = transition.state_p1
reward = transition.reward
done = transition.done
q_t, end_points_t = q_func(state, scope='q_func')
num_actions = q_t.get_shape().as_list()[1]
q_t_selected = tf.reduce_sum(q_t * tf.one_hot(action, num_actions), 1)
if gamma != 0:
if target_network_type == DQNTarget.notarget:
# Evaluate target values using the current net only.
q_tp1_using_online_net, end_points_tp1 = q_func(
state_p1, scope='q_func', reuse=True)
q_tp1_best = tf.reduce_max(q_tp1_using_online_net, 1)
elif target_network_type == DQNTarget.normal:
# Target network Q values at t+1.
q_tp1_target, end_points_tp1 = q_func(state_p1, scope='target_q_func')
q_tp1_best = tf.reduce_max(q_tp1_target, 1)
elif target_network_type == DQNTarget.doubleq:
q_tp1_target, end_points_tp1 = q_func(state_p1, scope='target_q_func')
# Q values at t+1.
q_tp1_using_online_net, _ = q_func(state_p1, scope='q_func', reuse=True)
# Q values for action we select at t+1.
q_tp1_best_using_online_net = tf.one_hot(
tf.argmax(q_tp1_using_online_net, 1), num_actions)
# Q value of target network at t+1, but for action selected by online net.
q_tp1_best = tf.reduce_sum(q_tp1_target * q_tp1_best_using_online_net, 1)
else:
logging.error('Invalid target_network_type %s', target_network_type)
q_tp1_best_masked = (1.0 - done) * q_tp1_best
q_t_selected_target = tf.stop_gradient(reward + gamma * q_tp1_best_masked)
else:
# Supervised target; no next-step endpoints when gamma == 0.
end_points_tp1 = None
q_t_selected_target = tf.stop_gradient(reward)
td_error = q_t_selected - q_t_selected_target
if extra_callback is not None:
extra_callback(transition, end_points_t, end_points_tp1)
tf.summary.histogram('td_error', td_error)
tf.summary.histogram('q_t_selected', q_t_selected)
tf.summary.histogram('q_t_selected_target', q_t_selected_target)
tf.summary.scalar('mean_q_t_selected', tf.reduce_mean(q_t_selected))
td_loss = loss_fn(q_t_selected_target, q_t_selected)
tf.summary.scalar('td_loss', td_loss)
reg_loss = tf.losses.get_regularization_loss()
tf.summary.scalar('reg_loss', reg_loss)
loss = tf.losses.get_total_loss()
tf.summary.scalar('total_loss', loss)
summaries = tf.summary.merge_all()
return loss, summaries
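# The sketch below is a minimal, illustrative way to wire up discrete_q_graph;
# the toy linear q_func, the placeholder shapes, and the fake transition tuple
# are assumptions made here for demonstration (the real SARSTransition
# namedtuple is defined elsewhere in this codebase).
def _example_discrete_q_graph():
  """Builds the DQN loss for a toy 4-action problem (illustration only)."""
  import collections
  fake_transition = collections.namedtuple(
      'FakeTransition', ['state', 'action', 'state_p1', 'reward', 'done'])

  def toy_q_func(state, scope, reuse=False):
    # Tiny linear Q-network over 4 discrete actions; tf.get_variable makes the
    # online/target variable reuse behavior explicit.
    with tf.variable_scope(scope, reuse=reuse):
      weights = tf.get_variable('weights', [8, 4])
      bias = tf.get_variable('bias', [4], initializer=tf.zeros_initializer())
      q_values = tf.matmul(state, weights) + bias
    return q_values, {'q_values': q_values}

  transition = fake_transition(
      state=tf.placeholder(tf.float32, [None, 8]),
      action=tf.placeholder(tf.int32, [None]),
      state_p1=tf.placeholder(tf.float32, [None, 8]),
      reward=tf.placeholder(tf.float32, [None]),
      done=tf.placeholder(tf.float32, [None]))
  return discrete_q_graph(toy_q_func, transition,
                          target_network_type=DQNTarget.doubleq, gamma=0.99)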
@gin.configurable
def random_sample_box(batch_size,
action_size,
num_samples,
minval=-1.,
maxval=1.):
"""Samples actions for each batch element uniformly from a hyperrectangle.
Args:
batch_size: tf.Tensor (dtype=tf.int32) or int representing the minibatch
size of the state tensors.
action_size: (int) Size of the continuous action space.
num_samples: (int) Number of action samples for each minibatch element.
minval: (float) Minimum value for each action dimension.
maxval: (float) Maximum value for each action dimension.
Returns:
Tensor (dtype=tf.float32) of shape (batch_size * num_samples, action_size).
"""
return tf.random_uniform(
(batch_size * num_samples, action_size), minval=minval, maxval=maxval)
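# Illustration: for a batch of 2 states, a 3-dimensional action space, and 4
# samples per state, the call below yields a (2 * 4, 3) = (8, 3) tensor of
# candidate actions drawn uniformly from [-1, 1] in every dimension (the
# helper name is a placeholder used only for this sketch).
def _example_random_sample_box():
  candidate_actions = random_sample_box(batch_size=2, action_size=3,
                                        num_samples=4)
  with tf.Session() as sess:
    return sess.run(candidate_actions)  # ndarray of shape (8, 3).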
def _q_tp1_notarget(q_func, state_p1, batch_size, num_samples, random_actions):
"""Evaluate target values at t+1 using online Q function (no target network).
Args:
q_func: Python function that takes in state, action, scope as input
and returns Q(state, action) and intermediate endpoints dictionary.
state_p1: Tensor (potentially any dtype) representing the next state.
batch_size: tf.Tensor (dtype=tf.int32) or int representing the minibatch
size of the state tensors.
num_samples: (int) Number of action samples for each minibatch element.
random_actions: tf.Tensor (dtype=tf.float32) of candidate actions.
Returns:
Tuple (q_tp1_best, end_points_tp1). See _get_q_tp1 docs for description.
"""
# Evaluate target values using the current net only.
q_tp1_using_online_net, end_points_tp1 = q_func(
state_p1, random_actions, scope='q_func', reuse=True)
q_tp1_using_online_net_2d = tf.reshape(
q_tp1_using_online_net, (batch_size, num_samples))
q_tp1_best = tf.reduce_max(q_tp1_using_online_net_2d, 1)
return q_tp1_best, end_points_tp1
def _q_tp1_normal(q_func, state_p1, batch_size, num_samples, random_actions):
"""Evaluate target values at t+1 using separate target network network.
Args:
q_func: Python function that takes in state, action, scope as input
and returns Q(state, action) and intermediate endpoints dictionary.
state_p1: Tensor (potentially any dtype) representing the next state.
batch_size: tf.Tensor (dtype=tf.int32) or int representing the minibatch
size of the state tensors.
num_samples: (int) Number of action samples for each minibatch element.
random_actions: tf.Tensor (dtype=tf.float32) of candidate actions.
Returns:
Tuple (q_tp1_best, end_points_tp1). See _get_q_tp1 docs for description.
"""
q_tp1_target, end_points_tp1 = q_func(
state_p1, random_actions, scope='target_q_func')
q_tp1_target_2d = tf.reshape(q_tp1_target, (batch_size, num_samples))
q_tp1_best = tf.reduce_max(q_tp1_target_2d, 1)
return q_tp1_best, end_points_tp1
def _q_tp1_doubleq(q_func,
state_p1,
batch_size,
action_size,
num_samples,
random_actions):
"""Q(s_p1, a_p1) via Double Q Learning with stochastic sampling.
Args:
q_func: Python function that takes in state, action, scope as input
and returns Q(state, action) and intermediate endpoints dictionary.
state_p1: Tensor (potentially any dtype) representing the next state.
batch_size: tf.Tensor (dtype=tf.int32) or int representing the minibatch
size of the state tensors.
action_size: (int) Size of the continuous action space.
num_samples: (int) Number of action samples for each minibatch element.
random_actions: tf.Tensor (dtype=tf.float32) of candidate actions.
Returns:
Tuple (q_tp1_best, end_points_tp1). See _get_q_tp1 docs for description.
"""
# Target Q values at t+1, for action selected by online net.
q_tp1_using_online_net, end_points_tp1 = q_func(
state_p1, random_actions, scope='q_func', reuse=True)
q_tp1_using_online_net_2d = tf.reshape(
q_tp1_using_online_net, (batch_size, num_samples))
q_tp1_indices_using_online_net = tf.argmax(q_tp1_using_online_net_2d, 1)
random_actions = tf.reshape(
random_actions, (batch_size, num_samples, action_size))
batch_indices = tf.cast(tf.range(batch_size), tf.int64)
indices = tf.stack([batch_indices, q_tp1_indices_using_online_net], axis=1)
# For each batch item, index into its (num_samples, action_size) subarray with
# the corresponding argmax index to yield the chosen action.
q_tp1_best_action = tf.gather_nd(random_actions, indices)
q_tp1_best, end_points_tp1 = q_func(
state_p1, q_tp1_best_action, scope='target_q_func')
return q_tp1_best, end_points_tp1
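# The action-selection step in _q_tp1_doubleq pairs each batch row with the
# argmax over its num_samples candidates and uses tf.gather_nd to pull out the
# winning action, which the target network then scores. A standalone numeric
# sketch of that indexing trick (all shapes and values are illustrative):
def _example_doubleq_action_selection():
  batch_size, num_samples, action_size = 2, 3, 2
  q_values_2d = tf.constant([[0.1, 0.9, 0.3],
                             [0.5, 0.2, 0.4]])  # (batch_size, num_samples)
  candidate_actions = tf.reshape(
      tf.range(batch_size * num_samples * action_size, dtype=tf.float32),
      (batch_size, num_samples, action_size))
  best_idx = tf.argmax(q_values_2d, 1)  # -> [1, 0]
  batch_idx = tf.cast(tf.range(batch_size), tf.int64)
  indices = tf.stack([batch_idx, best_idx], axis=1)
  best_actions = tf.gather_nd(candidate_actions, indices)
  # best_actions == [[2., 3.], [6., 7.]]: for each row, the candidate action
  # the online network ranked highest.
  return best_actions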
def _get_q_tp1(q_func,
state_p1,
batch_size,
action_size,
num_samples,
random_sample_fn,
target_network_type):
"""Computes non-discounted Bellman target Q(s_p1, a_p1).
Args:
q_func: Python function that takes in state, action, scope as input
and returns Q(state, action) and intermediate endpoints dictionary.
state_p1: Tensor (potentially any dtype) representing the next state.
batch_size: tf.Tensor (dtype=tf.int32) or int representing the minibatch
size of the state tensors.
action_size: (int) Size of continuous action space.
num_samples: (int) Number of action samples for each minibatch element.
random_sample_fn: See random_continuous_q_graph.
target_network_type: See random_continuous_q_graph.
Returns:
Tuple (q_tp1_best, end_points_tp1), where q_tp1_best holds the Q-values of
the best next-step actions, chosen greedily over the sampled actions for
each minibatch element in state_p1, and end_points_tp1 is any auxiliary
outputs computed via q_func.
"""
random_actions = random_sample_fn(batch_size, action_size, num_samples)
if target_network_type == DQNTarget.notarget:
q_tp1_best, end_points_tp1 = _q_tp1_notarget(
q_func, state_p1, batch_size, num_samples, random_actions)
elif target_network_type == DQNTarget.normal:
q_tp1_best, end_points_tp1 = _q_tp1_normal(
q_func, state_p1, batch_size, num_samples, random_actions)
elif target_network_type == DQNTarget.doubleq:
q_tp1_best, end_points_tp1 = _q_tp1_doubleq(
q_func, state_p1, batch_size, action_size, num_samples, random_actions)
else:
logging.error('Invalid target_network_type %s', target_network_type)
return q_tp1_best, end_points_tp1
@gin.configurable
def random_continuous_q_graph(q_func,
transition,
random_sample_fn=random_sample_box,
num_samples=10,
target_network_type=DQNTarget.normal,
gamma=1.0,
loss_fn=tf.losses.huber_loss,
extra_callback=None,
log_input_image=True):
"""Construct loss/summary graph for continuous Q-Learning via sampling.
This Q-function loss implementation is derived from OpenAI baselines, extended
to work in the continuous case. This function supports batch sizes whose value
is only known at runtime.
Args:
q_func: Python function that takes in state, action, scope as input
and returns Q(state, action) and intermediate endpoints dictionary.
transition: SARSTransition namedtuple.
random_sample_fn: Function that samples actions for Bellman Target
maximization.
num_samples: For each state, how many actions to randomly sample in order
to compute approximate max over Q values.
target_network_type: Option to use Q Learning without target network, Q
Learning with a target network (default), or Double-Q Learning with a
target network.
gamma: Discount factor.
loss_fn: Function that computes the td_loss tensor. Takes as arguments
(target value tensor, predicted value tensor).
extra_callback: Optional function that takes in (transition, end_points_t,
end_points_tp1) and adds additional TF graph elements.
log_input_image: If True, creates an image summary of the first element of
the state tuple (assumed to be an image tensor).
Returns:
A tuple (loss, summaries) where loss is a scalar loss tensor to minimize,
summaries are TensorFlow summaries.
"""
state = transition.state
action = transition.action
state_p1 = transition.state_p1
reward = transition.reward
done = transition.done
q_t_selected, end_points_t = q_func(state, action, scope='q_func')
if log_input_image:
tf.summary.image('input_image', state[0])
if gamma != 0:
action_size = action.get_shape().as_list()[1]
batch_size = tf.shape(done)[0]
q_tp1_best, end_points_tp1 = _get_q_tp1(
q_func, state_p1, batch_size, action_size, num_samples,
random_sample_fn, target_network_type)
# Bellman equation: Q(s, a) = r + gamma * max_{a_p1} Q(s_p1, a_p1).
# At terminal transitions, Q(s_T, a_T) is regressed to r alone and the
# max_{a_p1} Q(s_p1, a_p1) term is masked to zero.
q_tp1_best_masked = (1.0 - done) * q_tp1_best
# compute RHS of bellman equation
q_t_selected_target = tf.stop_gradient(reward + gamma * q_tp1_best_masked)
else:
# Supervised Target.
end_points_tp1 = None
q_t_selected_target = reward
td_error = q_t_selected - q_t_selected_target
if extra_callback is not None:
extra_callback(transition, end_points_t, end_points_tp1)
tf.summary.histogram('td_error', td_error)
tf.summary.histogram('q_t_selected', q_t_selected)
tf.summary.histogram('q_t_selected_target', q_t_selected_target)
tf.summary.scalar('mean_q_t_selected', tf.reduce_mean(q_t_selected))
td_loss = loss_fn(q_t_selected_target, q_t_selected)
tf.summary.scalar('td_loss', td_loss)
reg_loss = tf.losses.get_regularization_loss()
tf.summary.scalar('reg_loss', reg_loss)
loss = tf.losses.get_total_loss()
tf.summary.scalar('total_loss', loss)
summaries = tf.summary.merge_all()
return loss, summaries
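# A usage sketch for the continuous-action graph. The toy linear q_func, the
# placeholder shapes, and the fake transition tuple are assumptions for
# illustration (the real SARSTransition namedtuple is defined elsewhere);
# log_input_image is disabled because the toy state is a flat vector rather
# than an image tuple.
def _example_random_continuous_q_graph():
  """Builds the sampling-based continuous Q loss for a toy problem."""
  import collections
  fake_transition = collections.namedtuple(
      'FakeTransition', ['state', 'action', 'state_p1', 'reward', 'done'])

  def toy_q_func(state, action, scope, reuse=False):
    # At t+1 the graph passes batch_size states but batch_size * num_samples
    # candidate actions, so repeat each state to line up with its action block.
    ratio = tf.shape(action)[0] // tf.shape(state)[0]
    state = tf.reshape(
        tf.tile(state[:, None, :], tf.stack([1, ratio, 1])), [-1, 8])
    with tf.variable_scope(scope, reuse=reuse):
      weights = tf.get_variable('weights', [8 + 2, 1])
      bias = tf.get_variable('bias', [1], initializer=tf.zeros_initializer())
      q_value = tf.squeeze(
          tf.matmul(tf.concat([state, action], 1), weights) + bias, axis=1)
    return q_value, {'q_value': q_value}

  transition = fake_transition(
      state=tf.placeholder(tf.float32, [None, 8]),
      action=tf.placeholder(tf.float32, [None, 2]),
      state_p1=tf.placeholder(tf.float32, [None, 8]),
      reward=tf.placeholder(tf.float32, [None]),
      done=tf.placeholder(tf.float32, [None]))
  return random_continuous_q_graph(
      toy_q_func, transition, num_samples=16,
      target_network_type=DQNTarget.doubleq, gamma=0.99, log_input_image=False)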
def _get_tau_var(tau, tau_curriculum_steps):
"""Variable which increases linearly from 0 to tau over so many steps."""
if tau_curriculum_steps > 0:
tau_var = tf.get_variable('tau', [],
initializer=tf.constant_initializer(0.0),
trainable=False)
tau_var = tau_var.assign(
tf.minimum(float(tau), tau_var + float(tau) / tau_curriculum_steps))
else:
tau_var = tf.get_variable('tau', [],
initializer=tf.constant_initializer(float(tau)),
trainable=False)
return tau_var
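# A small sketch of the curriculum behavior: with tau=1.0 and
# tau_curriculum_steps=4, each evaluation of the returned op advances the
# variable by 0.25 until it saturates at 1.0 (the enclosing scope name below is
# just for this example).
def _example_tau_curriculum():
  with tf.variable_scope('tau_example'):
    tau_var = _get_tau_var(tau=1.0, tau_curriculum_steps=4)
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    return [sess.run(tau_var) for _ in range(6)]  # 0.25, 0.5, 0.75, 1.0, 1.0, 1.0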
def _get_pcl_values(q_func, not_pad, state, tstep, action,
random_sample_fn, num_samples, target_network_type):
"""Computes Q- and V-values for batch of episodes."""
# get dimensions of input
batch_size = tf.shape(state)[0]
episode_length = tf.shape(state)[1]
img_height = state.get_shape().as_list()[2]
img_width = state.get_shape().as_list()[3]
img_channels = state.get_shape().as_list()[4]
action_size = action.get_shape().as_list()[2]
# flatten input so each row corresponds to single transition
flattened_state = tf.reshape(state, [batch_size * episode_length,
img_height, img_width, img_channels])
flattened_tstep = tf.reshape(tstep, [batch_size * episode_length])
flattened_action = tf.reshape(action,
[batch_size * episode_length, action_size])
flattened_q_t, end_points_q_t = q_func(
(flattened_state, flattened_tstep), flattened_action, scope='q_func')
flattened_v_t, end_points_v_t = _get_q_tp1(
q_func, (flattened_state, flattened_tstep),
batch_size * episode_length, action_size, num_samples,
random_sample_fn, target_network_type)
# reshape to correspond to original input
q_t = not_pad * tf.reshape(flattened_q_t, [batch_size, episode_length])
v_t = not_pad * tf.reshape(flattened_v_t, [batch_size, episode_length])
v_t = tf.stop_gradient(v_t)
return q_t, v_t, end_points_q_t, end_points_v_t
@gin.configurable
def random_continuous_pcl_graph(q_func,
transition,
random_sample_fn=random_sample_box,
num_samples=10,
target_network_type=None,
gamma=1.0,
rollout=20,
loss_fn=tf.losses.huber_loss,
tau=1.0,
tau_curriculum_steps=0,
stop_gradient_on_adv=False,
extra_callback=None):
"""Construct loss/summary graph for continuous PCL via sampling.
This is an implementation of "Corrected MC", a specific variant of PCL.
See https://arxiv.org/abs/1802.10264
Args:
q_func: Python function that takes in state, action, scope as input
and returns Q(state, action) and intermediate endpoints dictionary.
transition: SARSTransition namedtuple containing a batch of episodes.
random_sample_fn: Function that samples actions for Bellman Target
maximization.
num_samples: For each state, how many actions to randomly sample in order
to compute approximate max over Q values.
target_network_type: Option to use Q Learning without target network, Q
Learning with a target network (default), or Double-Q Learning with a
target network.
gamma: Float discount factor.
rollout: Integer rollout parameter. When rollout = 1 we recover Q-learning.
loss_fn: Function that computes the td_loss tensor. Takes as arguments
(target value tensor, predicted value tensor).
tau: Coefficient on correction terms (i.e. on advantages).
tau_curriculum_steps: Increase tau linearly from 0 over this many steps.
stop_gradient_on_adv: If True, stop gradients through the advantage terms so
that q-values are not trained toward targets in the past.
extra_callback: Optional function that takes in (transition, end_points_t,
end_points_tp1) and adds additional TF graph elements.
Returns:
A tuple (loss, summaries) where loss is a scalar loss tensor to minimize,
summaries are TensorFlow summaries.
"""
if target_network_type is None:
target_network_type = DQNTarget.normal
tau_var = _get_tau_var(tau, tau_curriculum_steps)
state, tstep = transition.state
action = transition.action
reward = transition.reward
done = transition.done
not_pad = tf.to_float(tf.equal(tf.cumsum(done, axis=1, exclusive=True), 0.0))
reward *= not_pad
q_t, v_t, end_points_q_t, end_points_v_t = _get_pcl_values(
q_func, not_pad, state, tstep, action,
random_sample_fn, num_samples, target_network_type)
discounted_sum_rewards = discounted_future_sum(reward, gamma, rollout)
advantage = q_t - v_t # equivalent to tau * log_pi in PCL
if stop_gradient_on_adv:
advantage = tf.stop_gradient(advantage)
discounted_sum_adv = discounted_future_sum(
shift_values(advantage, gamma, 1), gamma, rollout - 1)
last_v = shift_values(v_t, gamma, rollout)
# values we regress on
pcl_values = q_t
# targets we regress to
pcl_targets = -tau_var * discounted_sum_adv + discounted_sum_rewards + last_v
# error is discrepancy between values and targets
pcl_error = pcl_values - pcl_targets
if extra_callback:
extra_callback(transition, end_points_q_t, end_points_v_t)
tf.summary.histogram('pcl_error', pcl_error)
tf.summary.histogram('q_t', q_t)
tf.summary.histogram('v_t', v_t)
tf.summary.scalar('mean_q_t', tf.reduce_mean(q_t))
pcl_loss = loss_fn(pcl_values, pcl_targets, weights=not_pad)
tf.summary.scalar('pcl_loss', pcl_loss)
reg_loss = tf.losses.get_regularization_loss()
tf.summary.scalar('reg_loss', reg_loss)
loss = tf.losses.get_total_loss()
tf.summary.scalar('total_loss', loss)
summaries = tf.summary.merge_all()
return loss, summaries
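# Reading the targets above back into equation form (a derivation from the
# code, with d = rollout): each Q-value is regressed toward the d-step path
# consistency
#
#   Q(s_t, a_t) ~= sum_{i=0..d-1} gamma^i r_{t+i}
#                  + gamma^d V(s_{t+d})
#                  - tau * sum_{i=1..d-1} gamma^i (Q(s_{t+i}, a_{t+i}) - V(s_{t+i})),
#
# where V is the sampled max over actions from _get_q_tp1 (treated as a
# constant via stop_gradient) and the advantage Q - V plays the role of
# tau * log_pi. Whether the Q-values inside the correction sum also receive
# gradients is controlled by stop_gradient_on_adv. With d = 1 the correction
# sum is empty and the target reduces to the standard Q-learning target.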
def shift_values(values, discount, rollout):
"""Shift values up by some amount of time.
Args:
values: Tensor of shape [batch_size, time].
discount: Scalar (float) representing discount factor.
rollout: Amount (int) to shift values in time by.
Returns:
Tensor of shape [batch_size, time] with values shifted.
"""
final_values = tf.zeros_like(values[:, 0])
roll_range = tf.cumsum(tf.ones_like(values[:, :rollout]), 0,
exclusive=True, reverse=True)
final_pad = tf.expand_dims(final_values, 1) * discount ** roll_range
return tf.concat([discount ** rollout * values[:, rollout:],
final_pad], 1)
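# A numeric sketch: with values = [[1., 2., 3., 4.]], discount = 0.5 and
# rollout = 2, shift_values returns [[0.25 * 3., 0.25 * 4., 0., 0.]], i.e.
# each entry becomes the discounted value rollout steps later, zero-padded at
# the end of the episode.
def _example_shift_values():
  values = tf.constant([[1., 2., 3., 4.]])
  shifted = shift_values(values, discount=0.5, rollout=2)
  with tf.Session() as sess:
    return sess.run(shifted)  # [[0.75, 1.0, 0.0, 0.0]]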
def discounted_future_sum(values, discount, rollout):
"""Discounted future sum of values.
Args:
values: A tensor of shape [batch_size, episode_length].
discount: Scalar discount factor.
rollout: Number of future steps over which to compute the sum.
Returns:
Tensor of same shape as values.
"""
if not rollout:
return tf.zeros_like(values)
discount_filter = tf.reshape(
discount ** tf.range(float(rollout)), [-1, 1, 1])
expanded_values = tf.concat(
[values, tf.zeros([tf.shape(values)[0], rollout - 1])], 1)
conv_values = tf.squeeze(tf.nn.conv1d(
tf.expand_dims(expanded_values, -1), discount_filter,
stride=1, padding='VALID'), -1)
return conv_values
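# A numeric sketch of the conv1d trick above: with values = [[1., 2., 3., 4.]],
# discount = 0.5 and rollout = 2, entry t becomes
# values[t] + 0.5 * values[t + 1] (zero-padded past the episode end),
# i.e. [[2.0, 3.5, 5.0, 4.0]].
def _example_discounted_future_sum():
  values = tf.constant([[1., 2., 3., 4.]])
  sums = discounted_future_sum(values, discount=0.5, rollout=2)
  with tf.Session() as sess:
    return sess.run(sums)  # [[2.0, 3.5, 5.0, 4.0]]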