
Commit

complete myppo.py
Buzzy0423 committed Nov 14, 2023
1 parent dda55dc commit ada295a
Showing 9 changed files with 136 additions and 210 deletions.
188 changes: 136 additions & 52 deletions myppo.py
@@ -5,10 +5,10 @@
import torch.optim as optim
from torch.distributions import Categorical
import numpy as np
import imageio
import os
from PIL import Image, ImageDraw, ImageFont
import seaborn as sns
import matplotlib.pyplot as plt
from tqdm import tqdm

class PolicyNet(nn.Module):
def __init__(self, state_dim, hidden_dim, action_dim):
@@ -32,96 +32,180 @@ def forward(self, x):
return self.fc2(x)

class PPO:
- def __init__(self, state_dim, hidden_dim, action_dim, lr,
- lmbda, epochs, eps, gamma, mini_batch_size):
+ def __init__(self, state_dim, hidden_dim, action_dim, lr_actor=0.0001, lr_critic=0.001,
+ lmbda=0.95, epochs=3, eps=0.3, gamma=0.99, beta=0.03):
self.actor = PolicyNet(state_dim, hidden_dim, action_dim)
self.critic = ValueNet(state_dim, hidden_dim)
self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),
- lr=lr)
+ lr=lr_actor)
self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),
- lr=lr)
+ lr=lr_critic)
self.gamma = gamma
self.lmbda = lmbda
self.epochs = epochs
self.eps = eps
+ self.beta = beta

self.data = []

def input_data(self, result):
self.data.append(result)
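
For orientation, a quick usage sketch of the new constructor defaults (assuming the class above is importable from myppo; the dimensions match the CartPole setup used in main() further down):

from myppo import PPO

agent = PPO(state_dim=4, hidden_dim=256, action_dim=2)  # defaults: lr_actor=1e-4, lr_critic=1e-3, eps=0.3, beta=0.03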

def make_batch(self):
- states, actions, rewards, next_states, dones, old_log_probs = [], [], [], [], [], []
+ states, actions, rewards, old_log_probs, next_states, done_mask = [], [], [], [], [], []
for result in self.data:
- state, action, reward, next_state, done, old_log_prob = result
+ state, action, reward, old_log_prob, next_state, done = result
states.append(state)
- actions.append(action)
- rewards.append(reward)
+ actions.append([action])
+ rewards.append([reward])
+ old_log_probs.append([old_log_prob])
next_states.append(next_state)
- dones.append(done)
- old_log_probs.append(old_log_prob)
- return states, actions, rewards, next_states, dones, old_log_probs
+ done_mask.append([0 if done else 1])
+ return states, actions, rewards, old_log_probs, next_states, done_mask



def get_action(self, state):
if isinstance(state, list):
state = np.array(state, dtype=np.float32)

# Convert the numpy array to a PyTorch tensor
state_tensor = torch.from_numpy(state).float()

# Ensure state_tensor is 2D with shape [1, num_features]
if len(state_tensor.shape) == 1:
state_tensor = state_tensor.unsqueeze(0)

probs = self.actor(state_tensor)
action_dist = Categorical(probs)
action = action_dist.sample()
log_prob = action_dist.log_prob(action)
return log_prob, action.item()

def get_value(self, state):
if isinstance(state, list):
state = np.array(state, dtype=np.float32)

# Convert the numpy array to a PyTorch tensor
state_tensor = torch.from_numpy(state).float()

# Ensure state_tensor is 2D with shape [1, num_features]
if len(state_tensor.shape) == 1:
state_tensor = state_tensor.unsqueeze(0)

return self.critic(state_tensor)

def compute_gae(self, next_value, rewards, masks, values):
- values = values + [next_value]
+ rewards = torch.tensor(rewards, dtype=torch.float)
+ masks = torch.tensor(masks, dtype=torch.float)
+
+ values = torch.cat((values, next_value), dim=0)

gae = 0
returns = []
for step in reversed(range(len(rewards))):
- delta = rewards[step] + self.gamma * \
- values[step + 1] * masks[step] - values[step]
+ next_value = values[step + 1] if step < len(rewards) - 1 else 0
+
+ delta = rewards[step] + self.gamma * next_value * masks[step] - values[step]
gae = delta + self.gamma * self.lmbda * masks[step] * gae
returns.insert(0, gae + values[step])

return returns
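
For reference, a minimal standalone sketch of the GAE recursion that compute_gae above is built around (textbook form with a bootstrap value for the final step; the function name and toy inputs are illustrative, not part of the commit):

import torch

def gae_returns(rewards, masks, values, next_value, gamma=0.99, lmbda=0.95):
    # rewards, masks, values: 1-D tensors of length T; masks[t] is 0 where the episode ended
    values = torch.cat((values, next_value.reshape(1)), dim=0)  # append bootstrap value V(s_T)
    gae, returns = 0.0, []
    for t in reversed(range(len(rewards))):
        delta = rewards[t] + gamma * values[t + 1] * masks[t] - values[t]  # TD error
        gae = delta + gamma * lmbda * masks[t] * gae                       # discounted sum of TD errors
        returns.insert(0, gae + values[t])                                 # advantage + baseline = return target
    return torch.stack(returns)

# toy check: three unit rewards, episode terminates on the last step
print(gae_returns(torch.tensor([1., 1., 1.]), torch.tensor([1., 1., 0.]),
                  torch.tensor([0.5, 0.5, 0.5]), torch.tensor(0.5)))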

- def update(self):
+ def update(self, data):
for _ in range(self.epochs):
- batches = self.make_batch()
- for batch in batches:
- states, actions, rewards, next_states, dones, old_log_probs = batch
- returns = self.compute_gae(next_value, rewards, masks, values)
- returns = torch.cat(returns).detach()
- old_log_probs = torch.cat(old_log_probs).detach()
-
- for _ in range(self.K_epoch):
- log_probs = self.actor(states)
- dist = Categorical(log_probs)
- entropy = dist.entropy().mean()
- values = self.critic(states)
- advantage = returns - values
- surr1 = torch.exp(log_probs) * advantage
- surr2 = torch.exp(old_log_probs) * advantage
- loss = -torch.min(surr1, surr2) + F.smooth_l1_loss(values, returns.detach())
- self.actor_optimizer.zero_grad()
- self.critic_optimizer.zero_grad()
- loss.mean().backward()
- self.actor_optimizer.step()
- self.critic_optimizer.step()
+ for state, action, return_, advantage, old_log_prob in data:
+ state = torch.tensor(state, dtype=torch.float)
+ action = torch.tensor(action, dtype=torch.long)
+ old_log_prob = torch.tensor(old_log_prob, dtype=torch.float)
+
+ probs = self.actor(state.unsqueeze(0))
+ dist = Categorical(probs)
+ entropy = dist.entropy().mean()
+ new_log_prob = dist.log_prob(action)
+ ratio = torch.exp(new_log_prob - old_log_prob)
+
+ surr1 = ratio * advantage
+ surr2 = torch.clamp(ratio, 1-self.eps, 1+self.eps) * advantage
+ actor_loss = -torch.min(surr1, surr2).mean() - self.beta * entropy
+ critic_loss = (return_ - self.critic(state))**2
+ critic_loss = critic_loss.mean()
+
+ self.actor_optimizer.zero_grad()
+ actor_loss.backward()
+ self.actor_optimizer.step()
+
+ self.critic_optimizer.zero_grad()
+ critic_loss.backward()
+ self.critic_optimizer.step()

self.data = []
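
As a side note, the heart of the new per-transition actor update is PPO's clipped surrogate objective. A minimal sketch of just that piece (the toy numbers are assumptions for illustration only):

import torch

def clipped_surrogate_loss(new_log_prob, old_log_prob, advantage, eps=0.3):
    ratio = torch.exp(new_log_prob - old_log_prob)            # pi_new(a|s) / pi_old(a|s)
    surr1 = ratio * advantage
    surr2 = torch.clamp(ratio, 1 - eps, 1 + eps) * advantage  # clipped ratio
    return -torch.min(surr1, surr2).mean()                    # pessimistic of the two

# toy example: a ratio of ~2.2 is clipped to 1.3, keeping the policy step conservative
print(clipped_surrogate_loss(torch.tensor(-0.1), torch.tensor(-0.9), torch.tensor(2.0)))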


def train(self):

states, actions, rewards, old_log_probs, next_states, done_masks = self.make_batch()

# Compute advantages and returns
values = self.get_value(states).detach().squeeze()
next_value = self.get_value(next_states).detach().squeeze()
returns = self.compute_gae(next_value, rewards, done_masks, values)
returns = torch.tensor(returns)
returns = (returns - returns.mean()) / (returns.std() + 1e-8)
advantages = returns - values
advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

self.update(zip(states, actions, returns, advantages, old_log_probs))


def save_frames_as_gif(frames, path='./', filename='ppo.gif'):
pil_images = [Image.fromarray(frame) for frame in frames]

pil_images[0].save(
path + filename,
save_all=True,
append_images=pil_images[1:],
optimize=False,
duration=50,
loop=0)

def main():
- PPO = PPO(state_dim=4, hidden_dim = 256, action_dim=2)
- env = gym.make('CartPole-v1')
+ ppo = PPO(state_dim=4, hidden_dim = 256, action_dim=2)
+ env = gym.make('CartPole-v1', render_mode='rgb_array')
score = 0.0
scores = []
- for n_epi in range(1000):
+ p_bar = tqdm(range(400))
+ frames = []
+ for n_epi in p_bar:
done = False
state = env.reset()
+ state = state[0]
while not done:
- for t in range(20):
- action = PPO.get_action(state)
- next_state, reward, done, _ = env.step(action)
- PPO.put_data((state, action, reward, next_state, done))
- state = next_state
- score += reward
- if done:
- break
- PPO.train_net()
+ log_prob, action = ppo.get_action(state)
+ next_state, reward, done, _, _ = env.step(action)
+ ppo.input_data((state, action, reward, log_prob, next_state, done))
+ state = next_state
+ score += reward
+ if n_epi % 100 == 0 or n_epi == 399:
+ frames.append(env.render())
+ if done:
+ break
+ ppo.train()
scores.append(score)
score = 0.0
print("# of episode :{}, avg score : {:.1f}".format(
n_epi, np.mean(scores[-10:])))
p_bar.set_postfix_str(f'epi: {n_epi}, s_avg: {np.mean(scores[-10:])}')
+ if n_epi % 100 == 0 or n_epi == 399:
+ save_frames_as_gif(frames, filename=f'ppo_{n_epi}.gif')
+ frames = []

env.close()

+ # draw plot
+ sns.set_theme()
+ plt.plot(scores)
+ plt.xlabel('Episode')
+ plt.ylabel('Score')
+ plt.savefig('./ppo.png', dpi=500)
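
The reset/step handling in main() above (state = state[0], unpacking five values from env.step) matches the newer Gym/Gymnasium API. A small illustration of that API; handling truncated explicitly is a suggestion here, the commit itself only checks the third element:

import gym  # gym >= 0.26 (or gymnasium) returns 5 values from step()

env = gym.make('CartPole-v1', render_mode='rgb_array')
obs, info = env.reset()                                         # reset returns (observation, info)
obs, reward, terminated, truncated, info = env.step(env.action_space.sample())
done = terminated or truncated                                  # either flag ends the episode
env.close()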



Binary file added ppo.png
Binary file added ppo_0.gif
Binary file added ppo_100.gif
Binary file added ppo_200.gif
Binary file added ppo_300.gif
Binary file added ppo_399.gif
113 changes: 0 additions & 113 deletions t.py

This file was deleted.


0 comments on commit ada295a
