Commit
algorithm main logic
lnpalmer committed Nov 24, 2017
1 parent 9ee4eb5 commit 1a71f51
Showing 2 changed files with 71 additions and 10 deletions.
7 changes: 6 additions & 1 deletion main.py
@@ -20,6 +20,9 @@
parser.add_argument('--clip', type=float, default=.1, help='probability ratio clipping range')
parser.add_argument('--gamma', type=float, default=.99, help='discount factor')
parser.add_argument('--lambd', type=float, default=.95, help='GAE lambda parameter')
parser.add_argument('--value-coef', type=float, default=1., help='value loss coefficient')
parser.add_argument('--entropy-coef', type=float, default=.01, help='entropy loss coefficient')
parser.add_argument('--max-grad-norm', type=float, default=.5, help='grad norm to clip at')
parser.add_argument('--seed', type=int, default=0, help='random seed')
args = parser.parse_args()

@@ -36,5 +39,7 @@
algorithm = PPO(policy, venv, optimizer, clip=args.clip, gamma=args.gamma,
lambd=args.lambd, worker_steps=args.worker_steps,
sequence_steps=args.sequence_steps,
batch_steps=args.batch_steps)
batch_steps=args.batch_steps,
value_coef=args.value_coef, entropy_coef=args.entropy_coef,
max_grad_norm=args.max_grad_norm)
algorithm.run(args.total_steps)
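The three new flags default to the values wired into PPO below and can be overridden on the command line, e.g. python main.py --value-coef 1.0 --entropy-coef 0.01 --max-grad-norm 0.5 (a usage sketch; the remaining arguments, not all visible in this hunk, keep their defaults).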
74 changes: 65 additions & 9 deletions ppo.py
@@ -1,3 +1,4 @@
import random
import copy
import torch
import torch.nn as nn
@@ -9,7 +10,8 @@

class PPO:
def __init__(self, policy, venv, optimizer, clip=.1, gamma=.99, lambd=.95,
worker_steps=128, sequence_steps=32, batch_steps=256, opt_epochs=3):
worker_steps=128, sequence_steps=32, batch_steps=256,
opt_epochs=3, value_coef=1., entropy_coef=.01, max_grad_norm=.5):
""" Proximal Policy Optimization algorithm class
Evaluates a policy over a vectorized environment and
@@ -37,11 +39,15 @@ def __init__(self, policy, venv, optimizer, clip=.1, gamma=.99, lambd=.95,
self.worker_steps = worker_steps
self.sequence_steps = sequence_steps
self.batch_steps = batch_steps
self.opt_epochs = opt_epochs

self.objective = PPOObjective(clip)
self.opt_epochs = opt_epochs
self.gamma = gamma
self.lambd = lambd
self.value_coef = value_coef
self.entropy_coef = entropy_coef
self.max_grad_norm = max_grad_norm

self.objective = PPOObjective(clip)

self.last_ob = self.venv.reset()

@@ -56,28 +62,76 @@ def run(self, total_steps):
N = self.num_workers
T = self.worker_steps
E = self.opt_epochs
A = self.venv.action_space.n

while taken_steps < total_steps:
obs, rewards, masks, actions, steps = self.interact()
ob_shape = obs.size()[2:]

# compute advantages, returns with GAE
# TEMP upgrade to support recurrence
obs = obs.view(((T + 1) * N,) + obs.size()[2:])
obs = Variable(obs)
_, values = self.policy(obs)
obs_ = obs.view(((T + 1) * N,) + ob_shape)
obs_ = Variable(obs_)
_, values = self.policy(obs_)
values = values.view(T + 1, N, 1)
advantages, returns = gae(rewards, masks, values, self.gamma, self.lambd)
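# advantages drive the clipped surrogate loss; returns serve as value-function targets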

self.policy_old.load_state_dict(self.policy.state_dict())
for e in range(E):
raise NotImplementedError
self.policy.zero_grad()

S = self.sequence_steps
B_total = steps // self.sequence_steps
B = self.batch_steps // self.sequence_steps

batch = random.sample(range(B_total), B)
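# the T*N worker steps are viewed as B_total contiguous length-S sequences;
# a minibatch gathers B whole sequences (columns), keeping time order within each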

b_obs = Variable(obs[:T].view((S, B_total) + ob_shape)[:, batch])
b_rewards = Variable(rewards.view(S, B_total, 1)[:, batch])
b_masks = Variable(masks.view(S, B_total, 1)[:, batch])
b_actions = Variable(actions.view(S, B_total, 1)[:, batch])
b_advantages = Variable(advantages.view(S, B_total, 1)[:, batch])
b_returns = Variable(returns.view(S, B_total, 1)[:, batch])

# eval policy and old policy
b_pis, b_vs, b_pi_olds, b_v_olds = [], [], [], []
for s in range(S):
pi, v = self.policy(b_obs[s])
b_pis.append(pi)
b_vs.append(v)
pi_old, v_old = self.policy_old(b_obs[s])
b_pi_olds.append(pi_old)
b_v_olds.append(v_old)
cat_fn = lambda l: torch.cat([e.unsqueeze(0) for e in l])
condensed = [cat_fn(item) for item in [b_pis, b_vs, b_pi_olds, b_v_olds]]
b_pis, b_vs, b_pi_olds, b_v_olds = tuple(condensed)
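# after stacking, b_pis / b_pi_olds are [S x B x A] and b_vs / b_v_olds are
# [S x B x 1]; they are flattened to [S*B x ...] for the objective below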

losses = self.objective(b_pis.view(S * B, A),
b_vs.view(S * B, 1),
b_pi_olds.view(S * B, A).detach(),
b_v_olds.view(S * B, 1).detach(),
b_actions.view(S * B, 1),
b_advantages.view(S * B, 1),
b_returns.view(S * B, 1))
policy_loss, value_loss, entropy_loss = losses
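# entropy_loss is sum(p * log p), i.e. negative entropy, so a positive
# entropy_coef favours higher-entropy (more exploratory) policies; the
# combined loss is backpropagated, gradients are clipped to max_grad_norm,
# and one optimizer step is taken per minibatch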

loss = policy_loss + value_loss * self.value_coef + entropy_loss * self.entropy_coef

print('policy loss: %s' % str(policy_loss))
print('value loss: %s' % str(value_loss))
print('entropy loss: %s' % str(entropy_loss))

loss.backward()
torch.nn.utils.clip_grad_norm(self.policy.parameters(), self.max_grad_norm)
self.optimizer.step()

taken_steps += steps

def interact(self):
""" Interacts with the environment
Returns:
obs (FloatTensor): observations shaped [T + 1 x N x ...]
rewards (FloatTensor): rewards shaped [T x N x 1]
masks (FloatTensor): continuation masks shaped [T x N x 1]
zero at done timesteps, one otherwise
@@ -118,6 +172,8 @@ def interact(self):

class PPOObjective(nn.Module):
def __init__(self, clip):
super().__init__()

self.clip = clip

def forward(self, pi, v, pi_old, v_old, action, advantage, returns):
@@ -154,7 +210,7 @@ def forward(self, pi, v, pi_old, v_old, action, advantage, returns):
surr2 = torch.clamp(ratio, min=1. - self.clip, max=1. + self.clip) * advantage

policy_loss = -torch.min(surr1, surr2).mean()
value_loss = (.5 * (values - returns) ** 2.).mean()
value_loss = (.5 * (v - returns) ** 2.).mean()
entropy_loss = (prob * log_prob).sum(1).mean()

return policy_loss, value_loss, entropy_loss
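The gae() helper used in run() is imported from elsewhere in the repository and is not part of this commit. For orientation only, here is a minimal sketch of a GAE(lambda) computation consistent with the shapes above (rewards and masks [T x N x 1], values [T + 1 x N x 1]); the name gae_sketch and the exact masking convention are assumptions, not the repo's implementation:

import torch

def gae_sketch(rewards, masks, values, gamma, lambd):
    """rewards, masks: [T x N x 1]; values: [T + 1 x N x 1] (last row bootstraps)."""
    T = rewards.size(0)
    advantages = torch.zeros(rewards.size())
    running = torch.zeros(rewards.size(1), 1)
    for t in reversed(range(T)):
        # one-step TD error; masks[t] == 0 stops bootstrapping at episode ends
        delta = rewards[t] + gamma * values[t + 1] * masks[t] - values[t]
        running = delta + gamma * lambd * masks[t] * running
        advantages[t] = running
    # value-function targets are advantages plus the value baseline
    returns = advantages + values[:-1]
    return advantages, returns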
