Update SAC.py
huanghanchi authored Jul 30, 2022
1 parent fbfc2ad commit f88baf5
Showing 1 changed file with 29 additions and 54 deletions.
83 changes: 29 additions & 54 deletions SAC.py
@@ -1,26 +1,26 @@
import gym

envs=[]
for env_name in ['YarsRevenge', 'Jamesbond', 'FishingDerby', 'Venture',
                 'DoubleDunk', 'Kangaroo', 'IceHockey', 'ChopperCommand', 'Krull',
                 'Robotank', 'BankHeist', 'RoadRunner', 'Hero', 'Boxing',
                 'Seaquest', 'PrivateEye', 'StarGunner', 'Riverraid',
                 'Zaxxon', 'Tennis', 'BattleZone',
                 'MontezumaRevenge', 'Frostbite', 'Gravitar',
                 'Defender', 'Pitfall', 'Solaris', 'Berzerk',
                 'Centipede'][:10]:
    env=gym.make(env_name)
    envs.append(env)
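# Editorial note, not part of this commit: with a stock gym/atari-py install the
# Atari games are registered under versioned ids (e.g. 'YarsRevengeNoFrameskip-v4'),
# so the bare names used here and in the duplicate block near the end of the file
# assume the repository registers them itself; otherwise something like
#   env = gym.make(env_name + 'NoFrameskip-v4')
# would be needed.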

import os
import sys
import yaml
import argparse
from datetime import datetime
from abc import ABC, abstractmethod
from collections import deque

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import rankdata

# Make the local constopt-pytorch checkout importable before importing constopt.
sys.path.insert(0, r'constopt-pytorch/')
import constopt
from constopt.constraints import LinfBall
from constopt.stochastic import PGD, PGDMadry, FrankWolfe, MomentumFrankWolfe

import torch
import torch.nn as nn
import torch.nn.utils as utils
from torch.nn import functional as F
from torch.optim import Adam
from torch.autograd import Variable
from torch.distributions import Categorical
from torch.utils.tensorboard import SummaryWriter

class BaseAgent(ABC):

@@ -279,31 +279,23 @@ def __del__(self):
        self.test_env.close()
        self.writer.close()

import torch
import torch.nn as nn
from torch.nn import functional as F
from torch.distributions import Categorical

def initialize_weights_he(m):
    if isinstance(m, nn.Linear) or isinstance(m, nn.Conv2d):
        torch.nn.init.kaiming_uniform_(m.weight)
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)
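# Usage sketch (editorial, not shown in this diff): the initializer is intended to be
# passed to nn.Module.apply, e.g. `conv_body.apply(initialize_weights_he)`, which
# gives every Linear/Conv2d layer Kaiming-uniform weights and zero biases.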


class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)


class BaseNetwork(nn.Module):
    def save(self, path):
        torch.save(self.state_dict(), path)

    def load(self, path):
        self.load_state_dict(torch.load(path))


class QNetwork(BaseNetwork):

    def __init__(self, num_channels, num_actions, shared=False,
@@ -358,7 +350,6 @@ def forward(self, states):
        q2 = self.Q2(states)
        return q1, q2


class CateoricalPolicy(BaseNetwork):

    def __init__(self, num_channels, num_actions, shared=False):
@@ -544,11 +535,6 @@ def save_models(self, save_dir):
        self.online_critic.save(str(index)+'online_criticsac_newatari.pth')
        self.target_critic.save(str(index)+'target_criticsac_newatari.pth')

from collections import deque
import numpy as np
import torch


class MultiStepBuff:

    def __init__(self, maxlen=3):
@@ -588,7 +574,6 @@ def is_full(self):
    def __len__(self):
        return len(self.rewards)


class LazyMemory(dict):

    def __init__(self, capacity, state_shape, device):
@@ -659,7 +644,6 @@ def _sample(self, indices, batch_size):
    def __len__(self):
        return self._n


class LazyMultiStepMemory(LazyMemory):

    def __init__(self, capacity, state_shape, device, gamma=0.99,
@@ -687,10 +671,6 @@ def append(self, state, action, reward, next_state, done):
        else:
            self._append(state, action, reward, next_state, done)

from collections import deque
import numpy as np


def update_params(optim, loss,inner_rnd, retain_graph=False):
    optim.zero_grad()
    w=[]
@@ -699,7 +679,7 @@ def update_params(optim, loss,inner_rnd, retain_graph=False):
    pre=loss
    with torch.autograd.set_detect_anomaly(True):
        pre.backward(retain_graph=retain_graph)
    #nn.utils.clip_grad_norm_(model.parameters(), 30)
    # nn.utils.clip_grad_norm_(model.parameters(), 30)
    optim.step()
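# Editorial note, not part of this commit: torch.autograd.set_detect_anomaly(True)
# re-runs the backward pass with extra checks and is considerably slower, so it is
# normally enabled only while debugging; retain_graph=True is only needed when the
# same computation graph is backpropagated through more than once.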

def disable_gradients(network):
@@ -718,18 +698,6 @@ def append(self, x):
    def get(self):
        return np.mean(self.stats)

import torch
import matplotlib.pyplot as plt
import sys
sys.path.insert(0,r'constopt-pytorch/')
import constopt
from constopt.constraints import LinfBall
from constopt.stochastic import PGD, PGDMadry, FrankWolfe, MomentumFrankWolfe
import torch
from torch.autograd import Variable
import torch.nn.utils as utils
from scipy.stats import rankdata

def loss(rloss,w,B,mu=0.2,lamb=[0.01,0.01,0.01]):
    return torch.tensor([1+mu*(np.linalg.norm(B[t],ord=1)-np.linalg.norm(B[t][t],ord=1)) for t in range(len(envs))]).dot(rloss) \
        + lamb[0]*sum([sum([sum([torch.norm(w[i][t]-sum([B.T[t][j]*w[i][j] for j in range(len(envs))]),p=2)**2]) for i in range(2)]) for t in range(len(envs))])
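# Editorial sketch, not part of this commit: the one-liner above combines the
# per-task losses in rloss, each reweighted by 1 + mu*(||B[t]||_1 - ||B[t][t]||_1),
# with a penalty pulling each task's parameters w[i][t] toward a B-weighted mix of
# all tasks' parameters. Assuming rloss is a 1-D tensor of per-task losses, an
# equivalent, more explicit form would be:
def loss_readable(rloss, w, B, mu=0.2, lamb=(0.01, 0.01, 0.01)):
    T = len(envs)
    # Task weights: 1 + mu * (||B[t]||_1 - ||B[t][t]||_1)
    weights = torch.tensor([1 + mu * (np.linalg.norm(B[t], ord=1)
                                      - np.linalg.norm(B[t][t], ord=1))
                            for t in range(T)])
    task_term = weights.dot(rloss)
    # Relation penalty: sum_t sum_i ||w[i][t] - sum_j B.T[t][j] * w[i][j]||_2^2
    reg = sum(torch.norm(w[i][t] - sum(B.T[t][j] * w[i][j] for j in range(T)), p=2) ** 2
              for t in range(T) for i in range(2))
    return task_term + lamb[0] * reg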

@@ -745,21 +713,28 @@ def __init__(self):
args=parser()
with open(args.config) as f:
    config = yaml.load(f, Loader=yaml.SafeLoader)

envs=[]
for env_name in ['YarsRevenge', 'Jamesbond', 'FishingDerby', 'Venture',
                 'DoubleDunk', 'Kangaroo', 'IceHockey', 'ChopperCommand', 'Krull',
                 'Robotank', 'BankHeist', 'RoadRunner', 'Hero', 'Boxing',
                 'Seaquest', 'PrivateEye', 'StarGunner', 'Riverraid',
                 'Zaxxon', 'Tennis', 'BattleZone',
                 'MontezumaRevenge', 'Frostbite', 'Gravitar',
                 'Defender', 'Pitfall', 'Solaris', 'Berzerk',
                 'Centipede'][:10]:
    env=gym.make(env_name)
    envs.append(env)

rloss=[0.0 for i in range(len(envs))]
rewardsRec=[[] for i in range(len(envs))]
rewardsRec_nor=[[0] for i in range(len(envs))]
succeessRec=[[] for i in range(len(envs))]
try:
    rewardsRec=np.load('sac_newatari_rewardsRec.npy',allow_pickle=True)
    succeessRec=np.load('sac_newatari_succeessRec.npy',allow_pickle=True)
except:
    pass
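# Editorial note, not part of this commit: the bare `except: pass` swallows every
# possible error while restoring previous records. If the intent is only to tolerate
# missing files on the first run, catching FileNotFoundError explicitly would be safer:
#   except FileNotFoundError:
#       pass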

agents=[]
for index in range(len(envs)):
    # Create environments.
    env =envs[index]
    env = envs[index]
    test_env = envs[index]

    # Specify the directory to log.
@@ -779,4 +754,4 @@
for i_episode in range(10000):
    for index in range(len(envs)):
        rnd=i_episode
        agents[index].run(rnd)
        agents[index].run(rnd)
