Skip to content

Commit

Permalink
td target bug fix
Browse files Browse the repository at this point in the history
  • Loading branch information
seungeunrho committed Jul 11, 2019
1 parent c6f5131 commit 899cf7a
Showing 1 changed file with 9 additions and 10 deletions.
19 changes: 9 additions & 10 deletions a3c.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,20 +55,19 @@ def train(model, rank):
s = s_prime
if done:
break

R = 0.0
R_lst = []

s_final = torch.tensor(s_prime, dtype=torch.float)
R = 0.0 if done else model.v(s_final).item()
td_target_lst = []
for reward in r_lst[::-1]:
R = gamma * R + reward
R_lst.append([R])
R_lst.reverse()
td_target_lst.append([R])
td_target_lst.reverse()

done_mask = 0.0 if done else 1.0
s_batch, a_batch, R_batch, s_final = torch.tensor(s_lst, dtype=torch.float), torch.tensor(a_lst), \
torch.tensor(R_lst), torch.tensor(s_prime, dtype=torch.float)

td_target = R_batch + gamma * model.v(s_final) * done_mask
s_batch, a_batch, td_target = torch.tensor(s_lst, dtype=torch.float), torch.tensor(a_lst), \
torch.tensor(td_target_lst)
advantage = td_target - model.v(s_batch)

pi = model.pi(s_batch,softmax_dim=1)
pi_a = pi.gather(1,a_batch)
loss = -torch.log(pi_a) * advantage.detach() + F.smooth_l1_loss(model.v(s_batch), td_target.detach())
Expand Down

0 comments on commit 899cf7a

Please sign in to comment.