diff --git a/main.py b/main.py
index ccb8e49..29e06f2 100644
--- a/main.py
+++ b/main.py
@@ -20,6 +20,9 @@
 parser.add_argument('--clip', type=float, default=.1, help='probability ratio clipping range')
 parser.add_argument('--gamma', type=float, default=.99, help='discount factor')
 parser.add_argument('--lambd', type=float, default=.95, help='GAE lambda parameter')
+parser.add_argument('--value-coef', type=float, default=1., help='value loss coefficient')
+parser.add_argument('--entropy-coef', type=float, default=.01, help='entropy loss coefficient')
+parser.add_argument('--max-grad-norm', type=float, default=.5, help='grad norm to clip at')
 parser.add_argument('--seed', type=int, default=0, help='random seed')

 args = parser.parse_args()
@@ -36,5 +39,7 @@
 algorithm = PPO(policy, venv, optimizer,
                 clip=args.clip, gamma=args.gamma, lambd=args.lambd,
                 worker_steps=args.worker_steps, sequence_steps=args.sequence_steps,
-                batch_steps=args.batch_steps)
+                batch_steps=args.batch_steps,
+                value_coef=args.value_coef, entropy_coef=args.entropy_coef,
+                max_grad_norm=args.max_grad_norm)
 algorithm.run(args.total_steps)
diff --git a/ppo.py b/ppo.py
index 917cd42..f91dbd8 100644
--- a/ppo.py
+++ b/ppo.py
@@ -1,3 +1,4 @@
+import random
 import copy
 import torch
 import torch.nn as nn
@@ -9,7 +10,8 @@

 class PPO:
     def __init__(self, policy, venv, optimizer, clip=.1, gamma=.99, lambd=.95,
-                 worker_steps=128, sequence_steps=32, batch_steps=256, opt_epochs=3):
+                 worker_steps=128, sequence_steps=32, batch_steps=256,
+                 opt_epochs=3, value_coef=1., entropy_coef=.01, max_grad_norm=.5):
         """ Proximal Policy Optimization algorithm class

         Evaluates a policy over a vectorized environment and
@@ -37,11 +39,15 @@ def __init__(self, policy, venv, optimizer, clip=.1, gamma=.99, lambd=.95,
         self.worker_steps = worker_steps
         self.sequence_steps = sequence_steps
         self.batch_steps = batch_steps
-        self.opt_epochs = opt_epochs
-        self.objective = PPOObjective(clip)
+        self.opt_epochs = opt_epochs

         self.gamma = gamma
         self.lambd = lambd
+        self.value_coef = value_coef
+        self.entropy_coef = entropy_coef
+        self.max_grad_norm = max_grad_norm
+
+        self.objective = PPOObjective(clip)

         self.last_ob = self.venv.reset()
@@ -56,20 +62,68 @@ def run(self, total_steps):
         N = self.num_workers
         T = self.worker_steps
         E = self.opt_epochs
+        A = self.venv.action_space.n

         while taken_steps < total_steps:
             obs, rewards, masks, actions, steps = self.interact()
+            ob_shape = obs.size()[2:]

             # compute advantages, returns with GAE
             # TEMP upgrade to support recurrence
-            obs = obs.view(((T + 1) * N,) + obs.size()[2:])
-            obs = Variable(obs)
-            _, values = self.policy(obs)
+            obs_ = obs.view(((T + 1) * N,) + ob_shape)
+            obs_ = Variable(obs_)
+            _, values = self.policy(obs_)
             values = values.view(T + 1, N, 1)
             advantages, returns = gae(rewards, masks, values, self.gamma, self.lambd)

+            self.policy_old.load_state_dict(self.policy.state_dict())
             for e in range(E):
-                raise NotImplementedError
+                self.policy.zero_grad()
+
+                S = self.sequence_steps
+                B_total = steps // self.sequence_steps
+                B = self.batch_steps // self.sequence_steps
+
+                batch = random.sample(range(B_total), B)
+
+                b_obs = Variable(obs[:T].view((S, B_total) + ob_shape)[:, batch])
+                b_rewards = Variable(rewards.view(S, B_total, 1)[:, batch])
+                b_masks = Variable(masks.view(S, B_total, 1)[:, batch])
+                b_actions = Variable(actions.view(S, B_total, 1)[:, batch])
+                b_advantages = Variable(advantages.view(S, B_total, 1)[:, batch])
+                b_returns = Variable(returns.view(S, B_total, 1)[:, batch])
+
+                # eval policy and old policy
+                b_pis, b_vs, b_pi_olds, b_v_olds = [], [], [], []
+                for s in range(S):
+                    pi, v = self.policy(b_obs[s])
+                    b_pis.append(pi)
+                    b_vs.append(v)
+                    pi_old, v_old = self.policy_old(b_obs[s])
+                    b_pi_olds.append(pi_old)
+                    b_v_olds.append(v_old)
+                cat_fn = lambda l: torch.cat([e.unsqueeze(0) for e in l])
+                condensed = [cat_fn(item) for item in [b_pis, b_vs, b_pi_olds, b_v_olds]]
+                b_pis, b_vs, b_pi_olds, b_v_olds = tuple(condensed)
+
+                losses = self.objective(b_pis.view(S * B, A),
+                                        b_vs.view(S * B, 1),
+                                        b_pi_olds.view(S * B, A).detach(),
+                                        b_v_olds.view(S * B, 1).detach(),
+                                        b_actions.view(S * B, 1),
+                                        b_advantages.view(S * B, 1),
+                                        b_returns.view(S * B, 1))
+                policy_loss, value_loss, entropy_loss = losses
+
+                loss = policy_loss + value_loss * self.value_coef + entropy_loss * self.entropy_coef
+
+                print('policy loss: %s' % str(policy_loss))
+                print('value loss: %s' % str(value_loss))
+                print('entropy loss: %s' % str(entropy_loss))
+
+                loss.backward()
+                torch.nn.utils.clip_grad_norm(self.policy.parameters(), self.max_grad_norm)
+                self.optimizer.step()

             taken_steps += steps
@@ -77,7 +131,7 @@ def interact(self):
         """ Interacts with the environment

         Returns:
-        obs (FloatTensor): observations shaped [T + 1 x N x ...]
+        obs (FloatTensor): observations shaped [T + 1 x N x ...]
         rewards (FloatTensor): rewards shaped [T x N x 1]
         masks (FloatTensor): continuation masks shaped [T x N x 1]
             zero at done timesteps, one otherwise
@@ -118,6 +172,8 @@

 class PPOObjective(nn.Module):
     def __init__(self, clip):
+        super().__init__()
+
         self.clip = clip

     def forward(self, pi, v, pi_old, v_old, action, advantage, returns):
@@ -154,7 +210,7 @@ def forward(self, pi, v, pi_old, v_old, action, advantage, returns):
         surr2 = torch.clamp(ratio, min=1. - self.clip, max=1. + self.clip) * advantage

         policy_loss = -torch.min(surr1, surr2).mean()
-        value_loss = (.5 * (values - returns) ** 2.).mean()
+        value_loss = (.5 * (v - returns) ** 2.).mean()
         entropy_loss = (prob * log_prob).sum(1).mean()

         return policy_loss, value_loss, entropy_loss
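
Note (not part of the patch): the heart of this change is building one scalar objective from the three losses and clipping the gradient norm before the optimizer step. The sketch below restates that combination using current PyTorch APIs (plain tensors instead of Variable, clip_grad_norm_ with a trailing underscore); the function name ppo_total_loss, the logit/shape conventions, and the toy usage are illustrative assumptions, not code taken from this repository.

import torch
import torch.nn.functional as F

def ppo_total_loss(pi_logits, v, pi_logits_old, actions, advantages, returns,
                   clip=0.1, value_coef=1.0, entropy_coef=0.01):
    """Clipped-surrogate PPO loss plus weighted value and entropy terms.

    Assumed shapes: pi_logits / pi_logits_old are [B, A] action logits,
    v / advantages / returns are [B, 1], actions is [B, 1] (long dtype).
    """
    log_prob = F.log_softmax(pi_logits, dim=1)
    log_prob_old = F.log_softmax(pi_logits_old, dim=1).detach()
    prob = log_prob.exp()

    # probability ratio pi(a|s) / pi_old(a|s) for the sampled actions
    ratio = (log_prob.gather(1, actions) - log_prob_old.gather(1, actions)).exp()

    # clipped surrogate objective (maximized, so negated for a loss)
    surr1 = ratio * advantages
    surr2 = torch.clamp(ratio, 1.0 - clip, 1.0 + clip) * advantages
    policy_loss = -torch.min(surr1, surr2).mean()

    value_loss = (0.5 * (v - returns) ** 2).mean()
    entropy_loss = (prob * log_prob).sum(1).mean()  # negative entropy

    return policy_loss + value_coef * value_loss + entropy_coef * entropy_loss

if __name__ == "__main__":
    # toy check with random tensors standing in for policy outputs
    B, A = 8, 4
    pi = torch.randn(B, A, requires_grad=True)
    v = torch.randn(B, 1, requires_grad=True)
    loss = ppo_total_loss(pi, v,
                          pi_logits_old=torch.randn(B, A),
                          actions=torch.randint(A, (B, 1)),
                          advantages=torch.randn(B, 1),
                          returns=torch.randn(B, 1))
    loss.backward()
    torch.nn.utils.clip_grad_norm_([pi, v], max_norm=0.5)  # then optimizer.step()

In the patch itself the same combination appears as loss = policy_loss + value_loss * self.value_coef + entropy_loss * self.entropy_coef, followed by torch.nn.utils.clip_grad_norm(self.policy.parameters(), self.max_grad_norm) and self.optimizer.step().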