
Commit

Update distral.py
huanghanchi authored Jul 30, 2022
1 parent 98cc780 commit fbfc2ad
Showing 1 changed file with 39 additions and 54 deletions.
distral.py (93 changes: 39 additions & 54 deletions)
@@ -1,47 +1,18 @@
import metaworld
import random

ml10 = metaworld.ML10() # Construct the benchmark, sampling tasks

training_envs = []
for name, env_cls in ml10.train_classes.items():
env = env_cls()
task = random.choice([task for task in ml10.train_tasks
if task.env_name == name])
env.set_task(task)
training_envs.append(env)

for env in training_envs:
obs = env.reset() # Reset environment
a = env.action_space.sample() # Sample an action
obs, reward, done, info = env.step(a) # Step the environment with the sampled random action

envs = training_envs
num_envs = len(envs)


from torch.distributions.normal import Normal
import math

class parser:
def __init__(self):
self.gamma=0.99

self.alpha=0.9

self.beta=.5

self.seed=543

self.render=False

self.log_interval=10

self.envs=envs

args = parser()

pi = Variable(torch.FloatTensor([math.pi]))

def normalized_columns_initializer(weights, std=1.0):
out = torch.randn(weights.size())
out *= std / torch.sqrt(out.pow(2).sum(1).expand_as(out))
@@ -72,21 +43,30 @@ def __init__(self):
# not shared layers

self.mu_heads = nn.ModuleList ( [ nn.Sequential(
nn.Linear(12, 64),
nn.Tanh(),
nn.Linear(64,4)
nn.Linear(12, 128),
nn.ReLU(),
nn.Linear(128, 64),
nn.ReLU(),
nn.Linear(64, 32),
nn.ReLU(),
nn.Linear(32, 16),
nn.ReLU(),
nn.Linear(16,4)
) for i in range(self.num_envs+1) ] )
self.sigma2_heads =nn.ModuleList ( [ nn.Sequential(
nn.Linear(12, 64),
nn.Tanh(),
nn.Linear(64,4)
nn.Linear(12, 128),
nn.ReLU(),
nn.Linear(128, 64),
nn.ReLU(),
nn.Linear(64, 32),
nn.ReLU(),
nn.Linear(32, 16),
nn.ReLU(),
nn.Linear(16,4)
) for i in range(self.num_envs+1) ] )
self.value_heads = nn.ModuleList([nn.Linear(16, 4) for i in range(self.num_envs)])

self.apply(weights_init)
# +1 for the distilled policy


# initialize lists for holding run information
self.div = [[] for i in range(num_envs)]
self.saved_actions = [[] for i in range(self.num_envs)]
@@ -102,22 +82,11 @@ def forward(self, y, index):
sigma2 = self.sigma2_heads[index](x)
sigma = F.softplus(sigma2)
value = self.value_heads[index](x)


mu_dist = F.softmax(self.mu_heads[-1](x),dim=-1)[0]
sigma2_dist = self.sigma2_heads[-1](x)
sigma_dist = F.softplus(sigma2_dist)
return mu, sigma, value, mu_dist, sigma_dist


# for debugging
test = False
model = Policy()
# learning rate - might be useful to change
optimizer = optim.Adam(model.parameters(), lr=1e-3)
eps = np.finfo(np.float32).eps.item()
SavedAction = namedtuple('SavedAction', ['log_prob', 'value'])

def select_action(state, index):
'''given a state, this function chooses the action to take
arguments: state - observation matrix specifying the current model state
@@ -148,7 +117,6 @@ def select_action(state, index):
# model.div[index].append(torch.div(tsigma.sqrt(),sigma.sqrt()).log() + torch.div(sigma+(tmu-mu).pow(2),tsigma*2) - 0.5)
return prob.loc


def finish_episode():
policy_losses = []
value_losses = []
@@ -222,17 +190,34 @@ def finish_episode():
model.entropies = []
model.rewards = [[] for i in range(model.num_envs)]

ml10 = metaworld.MT10() # Construct the benchmark, sampling tasks
envs = []
for name, env_cls in ml10.train_classes.items():
env = env_cls()
task = random.choice([task for task in ml10.train_tasks
if task.env_name == name])
env.set_task(task)
envs.append(env)

pi = Variable(torch.FloatTensor([math.pi]))
eps = np.finfo(np.float32).eps.item()
SavedAction = namedtuple('SavedAction', ['log_prob', 'value'])
test = False
trained = False
args = parser()
model = Policy()
# learning rate - might be useful to change
optimizer = optim.Adam(model.parameters(), lr=1e-3)

running_reward = 10
run_reward = np.array([10 for i in range(num_envs)])
roll_length = np.array([0 for i in range(num_envs)])
trained = False
trained_envs = np.array([False for i in range(num_envs)])
rewardsRec = [[] for i in range(num_envs)]
for i_episode in range(6000):
p = np.random.random()
# roll = np.random.randint(2)
length = 0

for index, env in enumerate(envs):
# Train each environment simultaneously with the distilled policy
state = env.reset()
@@ -253,4 +238,4 @@ def finish_episode():
break
np.save('distral_rewardsRec.npy',rewardsRec)
torch.save(model.state_dict(), 'distral_params.pkl')
finish_episode()
finish_episode()
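
The hunks above use names such as nn, F, optim, np, Variable, and namedtuple whose import statements sit outside the changed lines and are therefore not shown. A minimal sketch of the import header the displayed code appears to assume (not part of this commit):

# Sketch of the import block the displayed hunks appear to rely on.
import math
import random
from collections import namedtuple

import metaworld
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable
from torch.distributions.normal import Normal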

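The commented-out line in select_action has the form of the closed-form KL divergence between the task head's Gaussian and the distilled head's Gaussian, reading sigma and tsigma as variances. A standalone sketch of that term under this reading; the helper name gaussian_kl is illustrative and does not appear in the file:

import torch

def gaussian_kl(mu, sigma, tmu, tsigma):
    # Element-wise KL( N(mu, sigma) || N(tmu, tsigma) ) with sigma and tsigma treated as variances:
    # log(sqrt(tsigma) / sqrt(sigma)) + (sigma + (tmu - mu)^2) / (2 * tsigma) - 0.5
    return (tsigma.sqrt() / sigma.sqrt()).log() + (sigma + (tmu - mu).pow(2)) / (2 * tsigma) - 0.5

# Identical Gaussians give zero divergence in every dimension.
mu, sigma = torch.zeros(4), torch.ones(4)
print(gaussian_kl(mu, sigma, mu, sigma))  # tensor([0., 0., 0., 0.])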
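
Most of select_action is collapsed in this diff; the visible pieces are the Normal import, the SavedAction namedtuple, the commented divergence term, and the final return prob.loc. A hedged sketch of how those pieces could fit together for a Gaussian policy; the sampling and bookkeeping details are assumptions rather than the file's actual body:

def select_action_sketch(model, state, index):
    # Forward pass through task-specific head `index`; the distilled head's outputs come back as well.
    state_t = torch.from_numpy(state).float().unsqueeze(0)
    mu, sigma, value, mu_dist, sigma_dist = model(state_t, index)
    # Treat sigma as a variance (consistent with the commented KL term) and build the policy distribution.
    prob = Normal(mu, sigma.sqrt())
    action = prob.sample()
    # Record log-probability and value estimate for the update in finish_episode.
    model.saved_actions[index].append(SavedAction(prob.log_prob(action), value))
    # The visible code returns the distribution's mean (prob.loc) rather than the sampled action.
    return prob.loc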