Skip to content

Commit

Permalink
Fix action indexing
Browse files (browse the repository at this point in the history)
  • Loading branch information
Kaixhin authored and soumith committed Dec 5, 2017
1 parent 9faf2c6 commit 82cef44
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 6 deletions.
6 changes: 3 additions & 3 deletions reinforcement_learning/actor_critic.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ def forward(self, x):
x = F.relu(self.affine1(x))
action_scores = self.action_head(x)
state_values = self.value_head(x)
- return F.softmax(action_scores), state_values
+ return F.softmax(action_scores, dim=1), state_values


model = Policy()
Expand All @@ -59,7 +59,7 @@ def select_action(state):
m = Multinomial(probs)
action = m.sample()
model.saved_actions.append(SavedAction(m.log_prob(action), state_value))
- return action.data
+ return action.data[0]


def finish_episode():
Expand Down Expand Up @@ -88,7 +88,7 @@ def finish_episode():
state = env.reset()
for t in range(10000): # Don't infinite loop while learning
action = select_action(state)
- state, reward, done, _ = env.step(action[0, 0])
+ state, reward, done, _ = env.step(action)
if args.render:
env.render()
model.rewards.append(reward)
Expand Down
6 changes: 3 additions & 3 deletions reinforcement_learning/reinforce.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def __init__(self):
def forward(self, x):
x = F.relu(self.affine1(x))
action_scores = self.affine2(x)
- return F.softmax(action_scores)
+ return F.softmax(action_scores, dim=1)


policy = Policy()
Expand All @@ -53,7 +53,7 @@ def select_action(state):
m = Multinomial(probs)
action = m.sample()
policy.saved_actions.append(m.log_prob(action))
- return action.data
+ return action.data[0]


def finish_episode():
Expand All @@ -79,7 +79,7 @@ def finish_episode():
state = env.reset()
for t in range(10000): # Don't infinite loop while learning
action = select_action(state)
- state, reward, done, _ = env.step(action[0, 0])
+ state, reward, done, _ = env.step(action)
if args.render:
env.render()
policy.rewards.append(reward)
Expand Down

0 comments on commit 82cef44

Please sign in to comment.