Skip to content

Commit

Permalink
not using encoder implemented
Browse files Browse the repository at this point in the history
  • Loading branch information
irmakguzey committed Jul 19, 2022
1 parent f318cf3 commit 38f97ef
Show file tree
Hide file tree
Showing 8 changed files with 202 additions and 79 deletions.
3 changes: 2 additions & 1 deletion contrastive_learning/configs/agent/sbfd.yaml
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
_target_: contrastive_learning.models.agents.sbfd.SBFD
loss_fn: infonce
loss_fn: mse
use_encoder: False
pos_encoder: ???
trans: ???
optimizer: ???
16 changes: 9 additions & 7 deletions contrastive_learning/configs/train.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@ pos_encoder:
hidden_dim: ???
out_dim: ??? # set it to z_dim

# this needs to be specified manually
experiment: sbfd


seed: 42
device: cuda
Expand All @@ -35,14 +34,17 @@ save_frequency: 10 # Frequency to save the model - there will be a test in each
train_dset_split: 0.8

# Hyperparameters to be used everywhere
batch_size: 512 # Batch size should be high
lr: 1e-2
weight_decay: 1e-5
batch_size: 256
lr: 1e-4
weight_decay: 0
# z_dim: 1000 # resnet18 gives out 1000 dimensions (needed for cpn)
z_dim: 16 # Positions will be mapped to embeddings
z_dim: 8 # Positions will be mapped to embeddings
pos_dim: 8 # this is needed for state based models
hidden_dim: 64 # Expand to 64 first then to z_dim
action_dim: 2
action_dim: 8 # TODO: This should be changed at some point

# experiment name is composed from the settings above (encoder flag, loss fn, batch size, hidden dim, lr, z_dim)
experiment: sbfd_ue_${agent.use_encoder}_lf_${agent.loss_fn}_bs_${batch_size}_hd_${hidden_dim}_lr_${lr}_zd_${z_dim}

distributed: true
num_workers: 4
Expand Down
32 changes: 31 additions & 1 deletion contrastive_learning/datasets/dataloaders.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import torch
import torch.utils.data as data
from omegaconf import DictConfig, OmegaConf
Expand Down Expand Up @@ -29,4 +30,33 @@ def get_dataloaders(cfg : DictConfig):
test_loader = data.DataLoader(test_dset, batch_size=cfg.batch_size, shuffle=test_sampler is None,
num_workers=cfg.num_workers, sampler=test_sampler)

return train_loader, test_loader, dataset
return train_loader, test_loader, dataset

if __name__ == "__main__":
    # Manual smoke test: build the dataset and the dataloaders from the
    # training config, then print the normalization bounds and one
    # denormalized sample from the test loader.

    # A single-process "gloo" group is set up before get_dataloaders is
    # called (presumably because it builds distributed samplers - TODO
    # confirm against get_dataloaders).
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = "29503"

    torch.distributed.init_process_group(backend='gloo', rank=0, world_size=1)
    try:
        torch.cuda.set_device(0)

        cfg = OmegaConf.load('/home/irmak/Workspace/DAWGE/contrastive_learning/configs/train.yaml')
        dset = StateDataset(data_dir=cfg.data_dir)

        # Only the test loader is inspected below
        _, test_loader, _ = get_dataloaders(cfg)

        action_min, action_max, corner_min, corner_max = dset.calculate_mins_maxs()
        print('action: [min: {}, max: {}], corners: [min: {}, max: {}]'.format(
            action_min, action_max, corner_min, corner_max
        ))

        pos, _, action = next(iter(test_loader))
        print('pos: {}'.format(pos))
        # The dataset bounds are NumPy arrays, so convert the tensors before
        # denormalizing instead of mixing torch and NumPy operands
        print(dset.denormalize_corner(pos[0].detach().numpy()))
        print(dset.denormalize_action(action.detach().numpy()))
    finally:
        # Always release the process group, even if loading fails
        torch.distributed.destroy_process_group()

34 changes: 26 additions & 8 deletions contrastive_learning/datasets/state_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ def __init__(self, data_dir: str, single_dir=False) -> None:
self.pos_corners += pickle.load(f) # We need all pos_pairs in the same order when we retrieve the data

# Calculate mean and std to normalize corners and actions
self.action_mean, self.action_std, self.corner_mean, self.corner_std = self.calculate_means_stds()

# self.action_mean, self.action_std, self.corner_mean, self.corner_std = self.calculate_means_stds()
self.action_min, self.action_max, self.corner_min, self.corner_max = self.calculate_mins_maxs()

def __len__(self):
return len(self.pos_corners)
Expand All @@ -48,12 +48,12 @@ def __getitem__(self, index):
self.received_ids.append(index)
curr_pos, next_pos, action = self.pos_corners[index]

# Normalize the positions
curr_pos = torch.FloatTensor((curr_pos - self.corner_mean) / self.corner_std)
next_pos = torch.FloatTensor((next_pos - self.corner_mean) / self.corner_std)
# Normalize the positions - TODO: If this works nicely then delete mean/std approach
curr_pos = torch.FloatTensor((curr_pos - self.corner_min) / (self.corner_max - self.corner_min))
next_pos = torch.FloatTensor((next_pos - self.corner_min) / (self.corner_max - self.corner_min))

# Normalize the actions
action = torch.FloatTensor((action - self.action_mean) / self.action_std)
action = torch.FloatTensor((action - self.action_min) / (self.action_max - self.action_min))

# return box_pos, dog_pos, next_box_pos, next_dog_pos, action
return torch.flatten(curr_pos), torch.flatten(next_pos), action
Expand All @@ -77,12 +77,30 @@ def calculate_means_stds(self):

return action_mean, action_std, corner_mean, corner_std

def calculate_mins_maxs(self):
    """Compute the bounds used for min-max normalization.

    Reads ``self.pos_corners``; each item is indexed as
    ``(curr_pos, next_pos, action)``. Only the current positions
    (item ``[0]``) and the first two action components are used.

    Returns:
        Tuple ``(action_min, action_max, corner_min, corner_max)``.
        The action bounds have shape ``(2,)``. The corner bounds hold the
        per-coordinate min/max taken over every sample and every corner,
        repeated once per corner so that they share the shape of a single
        position array (``(8, 2)`` for 8-corner data).

    Raises:
        ValueError: If the dataset holds no samples.
    """
    if len(self.pos_corners) == 0:
        raise ValueError('cannot compute mins/maxs of an empty dataset')

    corners = np.asarray(
        [sample[0] for sample in self.pos_corners], dtype=np.float64
    )  # (n_samples, n_corners, 2)
    actions = np.asarray(
        [(sample[2][0], sample[2][1]) for sample in self.pos_corners],
        dtype=np.float64,
    )  # (n_samples, 2)

    action_min, action_max = actions.min(axis=0), actions.max(axis=0)

    # Global bounds per coordinate, tiled so they broadcast against one
    # (n_corners, 2) position array
    n_corners = corners.shape[1]
    corner_min = np.repeat(corners.min(axis=(0, 1))[np.newaxis, :], n_corners, axis=0)
    corner_max = np.repeat(corners.max(axis=(0, 1))[np.newaxis, :], n_corners, axis=0)

    return action_min, action_max, corner_min, corner_max

def denormalize_action(self, action):  # action.shape: 2
    """Map a min-max normalized action back to its original scale.

    Inverse of ``(action - action_min) / (action_max - action_min)``,
    using the bounds stored in ``self.action_min`` / ``self.action_max``.
    """
    return (action * (self.action_max - self.action_min)) + self.action_min

def denormalize_corner(self, corner):  # corner.shape: (16)
    """Map flattened, min-max normalized corners back to their original scale.

    The input is reshaped to the shape of ``self.corner_min`` (one row per
    corner, ``(8, 2)`` for 8-corner data) and then the inverse of
    ``(corner - corner_min) / (corner_max - corner_min)`` is applied.
    """
    corner = corner.reshape(self.corner_min.shape)
    return (corner * (self.corner_max - self.corner_min)) + self.corner_min

def get_root_id(self, root):
print('root: {}, root_id[root]: {}'.format(
Expand Down
3 changes: 3 additions & 0 deletions contrastive_learning/models/agents/agent_inits.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ def init_cpn(cfg, device, rank): # This can be used for sbfd agents as well

def init_sbfd(cfg, device, rank): # This can be used for sbfd agents as well
# Initialize the encoder and the trans
if cfg.agent.use_encoder == False:
cfg.z_dim = cfg.pos_dim*2
print('z_dim: {}'.format(cfg.z_dim))
pos_encoder = hydra.utils.instantiate(cfg.pos_encoder,
input_dim=cfg.pos_dim*2,
hidden_dim=cfg.hidden_dim,
Expand Down
49 changes: 34 additions & 15 deletions contrastive_learning/models/agents/sbfd.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@ def __init__(self,
pos_encoder, # Pos -> Embedding model is named as encoder as well
trans,
optimizer,
use_encoder,
loss_fn) -> None:

self.pos_encoder = pos_encoder
self.trans = trans
self.optimizer = optimizer
self.use_encoder = use_encoder

self.loss_type = loss_fn
if loss_fn == "infonce":
Expand Down Expand Up @@ -58,18 +60,26 @@ def train_epoch(self, train_loader):
for batch in train_loader:
self.optimizer.zero_grad()
pos, pos_next, action = [b.to(self.device) for b in batch]
if self.use_encoder:
z, z_next = self.pos_encoder(pos), self.pos_encoder(pos_next) # b x z_dim
z_delta = self.trans(z, action) # b x z_dim
z_next_predict = z + z_delta
if self.loss_type == "mse":
loss = self.loss_fn(z_next, z_next_predict)
elif self.loss_type == "infonce":
loss = self.loss_fn(z, z_next, z_next_predict)
else:
pos_delta = self.trans(pos, action)
pos_next_predict = pos + pos_delta
if self.loss_type == "mse":
loss = self.loss_fn(pos_next, pos_next_predict)
elif self.loss_type == "infonce":
loss = self.loss_fn(pos, pos_next, pos_next_predict)

z, z_next = self.pos_encoder(pos), self.pos_encoder(pos_next) # b x z_dim
z_next_predict = self.trans(z, action) # b x z_dim
if self.loss_type == "mse":
loss = self.loss_fn(z_next, z_next_predict)
elif self.loss_type == "infonce":
loss = self.loss_fn(z, z_next, z_next_predict) # TODO: infonce was changed so you should check this
train_loss += loss.item()

# Back prop
loss.backward()
# nn.utils.clip_grad_norm_(parameters, 20)
self.optimizer.step()

return train_loss / len(train_loader)
Expand All @@ -83,14 +93,23 @@ def test_epoch(self, test_loader):

# Test for one epoch
for batch in test_loader:
obs, obs_next, action = [b.to(self.device) for b in batch]
pos, pos_next, action = [b.to(self.device) for b in batch]
with torch.no_grad():
z, z_next = self.pos_encoder(obs), self.pos_encoder(obs_next) # b x z_dim
z_next_predict = self.trans(z, action) # b x z_dim
if self.loss_type == "mse":
loss = self.loss_fn(z_next, z_next_predict)
elif self.loss_type == "infonce":
loss = self.loss_fn(z, z_next, z_next_predict) # TODO: infonce was changed so you should check this
test_loss += loss.item()
if self.use_encoder:
z, z_next = self.pos_encoder(pos), self.pos_encoder(pos_next) # b x z_dim
z_delta = self.trans(z, action) # b x z_dim
z_next_predict = z + z_delta
if self.loss_type == "mse":
loss = self.loss_fn(z_next, z_next_predict)
elif self.loss_type == "infonce":
loss = self.loss_fn(z, z_next, z_next_predict) # TODO: infonce was changed so you should check this
else:
pos_delta = self.trans(pos, action)
pos_next_predict = pos + pos_delta
if self.loss_type == "mse":
loss = self.loss_fn(pos_next, pos_next_predict)
elif self.loss_type == "infonce":
loss = self.loss_fn(pos, pos_next, pos_next_predict)
test_loss += loss.item()

return test_loss / len(test_loader)
16 changes: 14 additions & 2 deletions contrastive_learning/models/custom_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,29 @@ class Transition(nn.Module):
def __init__(self, z_dim, action_dim, hidden_dim=64):
    """Build the MLP that maps a (state, action) pair to a z_dim vector.

    Args:
        z_dim: Size of the state/embedding vector; also the output size.
        action_dim: Size of the action part of the network input.
            forward() concatenates ``action_dim / 2`` (truncated) copies
            of the action it receives - assumes the raw action has 2
            components, TODO confirm against the dataset.
        hidden_dim: Base width of the hidden layers (defaults to 64).
    """
    super().__init__()

    # Number of action copies that forward() concatenates to the state
    self.a_repeatition = int(action_dim / 2)
    self.z_dim = z_dim
    # Widths: h -> h -> 2h -> 4h -> 2h -> h -> z_dim, one ReLU between layers
    self.model = nn.Sequential(
        nn.Linear(z_dim + action_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, 2 * hidden_dim),
        nn.ReLU(),
        nn.Linear(2 * hidden_dim, 4 * hidden_dim),
        nn.ReLU(),
        nn.Linear(4 * hidden_dim, 2 * hidden_dim),
        nn.ReLU(),
        nn.Linear(2 * hidden_dim, hidden_dim),
        nn.ReLU(),
        nn.Linear(hidden_dim, z_dim),
    )

def forward(self, z, a):
    """Run the transition network on a state ``z`` and an action ``a``.

    The action is tiled along its last dimension so that
    ``self.a_repeatition`` copies are present (it is left untouched when
    that count is 0 or 1), appended to ``z`` along the last dimension, and
    the result is passed through ``self.model``.

    Returns:
        The output of ``self.model`` for the concatenated input.
    """
    if self.a_repeatition > 1:
        # One concatenation instead of growing the tensor inside a loop
        a = torch.cat([a] * self.a_repeatition, dim=-1)
    x = torch.cat((z, a), dim=-1)
    return self.model(x)
Expand Down
Loading

0 comments on commit 38f97ef

Please sign in to comment.