PPO objectives
lnpalmer committed Nov 24, 2017
1 parent 3917437 commit 9ee4eb5
Showing 2 changed files with 47 additions and 3 deletions.
1 change: 1 addition & 0 deletions main.py
@@ -11,6 +11,7 @@
parser.add_argument('env_id', type=str, help='Gym environment id')
parser.add_argument('--arch', type=str, default='cnn', help='policy architecture, {lstm, cnn}')
parser.add_argument('--num-workers', type=int, default=8, help='number of parallel actors')
parser.add_argument('--opt-epochs', type=int, default=3, help='optimization epochs between environment interaction')
parser.add_argument('--total-steps', type=int, default=int(10e6), help='total number of environment steps to take')
parser.add_argument('--worker-steps', type=int, default=128, help='steps per worker between optimization rounds')
parser.add_argument('--sequence-steps', type=int, default=32, help='steps per sequence (for backprop through time)')
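For reference, the new `--opt-epochs` flag would be threaded through to the `PPO` constructor added below. The call site in main.py is not part of this diff, so the wiring here is only a sketch, under the assumption that `policy`, `venv`, and `optimizer` are built as elsewhere in the repository:

# Hypothetical call site -- not code from this commit.
args = parser.parse_args()
ppo = PPO(policy, venv, optimizer,
          worker_steps=args.worker_steps,
          sequence_steps=args.sequence_steps,
          opt_epochs=args.opt_epochs)
ppo.run(args.total_steps)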
49 changes: 46 additions & 3 deletions ppo.py
@@ -9,12 +9,14 @@

class PPO:
    def __init__(self, policy, venv, optimizer, clip=.1, gamma=.99, lambd=.95,
                 worker_steps=128, sequence_steps=32, batch_steps=256):
                 worker_steps=128, sequence_steps=32, batch_steps=256, opt_epochs=3):
        """ Proximal Policy Optimization algorithm class

        Evaluates a policy over a vectorized environment and
        optimizes the policy, value, and entropy objectives.

        Assumes discrete action space.

        Args:
            policy (nn.Module): the policy to optimize
            venv (vec_env): the vectorized environment to use
@@ -35,6 +37,7 @@ def __init__(self, policy, venv, optimizer, clip=.1, gamma=.99, lambd=.95,
        self.worker_steps = worker_steps
        self.sequence_steps = sequence_steps
        self.batch_steps = batch_steps
        self.opt_epochs = opt_epochs

        self.objective = PPOObjective(clip)
        self.gamma = gamma
@@ -52,6 +55,7 @@ def run(self, total_steps):

        N = self.num_workers
        T = self.worker_steps
        E = self.opt_epochs

        while taken_steps < total_steps:
            obs, rewards, masks, actions, steps = self.interact()
@@ -64,6 +68,9 @@
            values = values.view(T + 1, N, 1)
            advantages, returns = gae(rewards, masks, values, self.gamma, self.lambd)

            for e in range(E):
                raise NotImplementedError

            taken_steps += steps

    def interact(self):
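The epoch loop above is intentionally left as `raise NotImplementedError` in this commit; only the objectives are introduced here. A sketch of what each optimization epoch would plausibly do with the data gathered by `interact()` is shown below. The minibatch helper, the flattened per-transition tensors, and the `self.policy` / `self.optimizer` attributes are assumptions for illustration, not code from the repository:

for e in range(E):
    # iterate over random minibatches of the T * N collected transitions
    for idx in sample_minibatch_indices(T * N, self.batch_steps):   # hypothetical helper
        pi, v = self.policy(obs_flat[idx])                          # assumed policy interface
        policy_loss, value_loss, entropy_loss = self.objective(
            pi, v,
            pi_old_flat[idx], v_old_flat[idx],
            actions_flat[idx],
            advantages_flat[idx], returns_flat[idx])

        # placeholder coefficients for the weighted sum of objectives
        loss = policy_loss + 1.0 * value_loss + .01 * entropy_loss

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()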
@@ -113,5 +120,41 @@ class PPOObjective(nn.Module):
    def __init__(self, clip):
        super(PPOObjective, self).__init__()
        self.clip = clip

    def forward(self, inputs):
        raise NotImplementedError
    def forward(self, pi, v, pi_old, v_old, action, advantage, returns):
        """ Computes PPO objectives

        Assumes discrete action space.

        Args:
            pi (Variable): discrete action logits, shaped [N x num_actions]
            v (Variable): value predictions, shaped [N x 1]
            pi_old (Variable): old discrete action logits, shaped [N x num_actions]
            v_old (Variable): old value predictions, shaped [N x 1]
            action (Variable): discrete actions, shaped [N x 1]
            advantage (Variable): action advantages, shaped [N x 1]
            returns (Variable): discounted returns, shaped [N x 1]

        Returns:
            policy_loss (Variable): policy surrogate loss, shaped [1]
            value_loss (Variable): value loss, shaped [1]
            entropy_loss (Variable): entropy loss, shaped [1]
        """
        prob = Fnn.softmax(pi)
        log_prob = Fnn.log_softmax(pi)
        action_prob = prob.gather(1, action)

        prob_old = Fnn.softmax(pi_old)
        action_prob_old = prob_old.gather(1, action)

        ratio = action_prob / (action_prob_old + 1e-10)

        advantage = (advantage - advantage.mean()) / (advantage.std() + 1e-5)

        surr1 = ratio * advantage
        surr2 = torch.clamp(ratio, min=1. - self.clip, max=1. + self.clip) * advantage

        policy_loss = -torch.min(surr1, surr2).mean()
        value_loss = (.5 * (v - returns) ** 2.).mean()
        entropy_loss = (prob * log_prob).sum(1).mean()

        return policy_loss, value_loss, entropy_loss
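Once the module is defined, a minimal smoke test can exercise the objective on dummy tensors. This is a sketch assuming 2017-era PyTorch, where inputs are still wrapped in `Variable`, and follows the shapes documented in the docstring above:

import torch
from torch.autograd import Variable
from ppo import PPOObjective

N, num_actions = 8, 4
pi = Variable(torch.randn(N, num_actions))
v = Variable(torch.randn(N, 1))
pi_old = Variable(torch.randn(N, num_actions))
v_old = Variable(torch.randn(N, 1))
action = Variable((torch.rand(N, 1) * num_actions).long())
advantage = Variable(torch.randn(N, 1))
returns = Variable(torch.randn(N, 1))

objective = PPOObjective(clip=.1)
policy_loss, value_loss, entropy_loss = objective(pi, v, pi_old, v_old, action, advantage, returns)
# each loss is a scalar Variable, ready for a weighted sum and .backward()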
