Commit f5fafb3 (parent: 58f772b): dataset and agent initialization generalized
Showing 13 changed files with 760 additions and 134 deletions.
@@ -0,0 +1,32 @@
import torch
import torch.utils.data as data
from omegaconf import DictConfig, OmegaConf

from contrastive_learning.datasets.state_dataset import StateDataset
from contrastive_learning.datasets.visual_dataset import VisualDataset

# Script to return dataloaders

def get_dataloaders(cfg: DictConfig):
    # Load dataset - splitting will be done with random splitter
    if cfg.dataset_type == 'state':
        dataset = StateDataset(data_dir=cfg.data_dir)
    else:
        dataset = VisualDataset(data_dir=cfg.data_dir, frame_interval=cfg.frame_interval, video_type=cfg.video_type)

    train_dset_size = int(len(dataset) * cfg.train_dset_split)
    test_dset_size = len(dataset) - train_dset_size

    # Random split the train and validation datasets
    train_dset, test_dset = data.random_split(dataset,
                                              [train_dset_size, test_dset_size],
                                              generator=torch.Generator().manual_seed(cfg.seed))
    train_sampler = data.DistributedSampler(train_dset, drop_last=True, shuffle=True) if cfg.distributed else None
    test_sampler = data.DistributedSampler(test_dset, drop_last=True, shuffle=False) if cfg.distributed else None  # val will not be shuffled

    train_loader = data.DataLoader(train_dset, batch_size=cfg.batch_size, shuffle=train_sampler is None,
                                   num_workers=cfg.num_workers, sampler=train_sampler)
    test_loader = data.DataLoader(test_dset, batch_size=cfg.batch_size, shuffle=test_sampler is None,
                                  num_workers=cfg.num_workers, sampler=test_sampler)

    return train_loader, test_loader, dataset
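
For reference, here is a minimal sketch of how get_dataloaders could be driven from a plain OmegaConf config in a single-process run. The field names mirror the ones read above; the module path, data directory, and concrete values are illustrative assumptions rather than part of this commit.

from omegaconf import OmegaConf
# Module path is an assumption; the diff does not show the new file's name.
from contrastive_learning.datasets.dataloaders import get_dataloaders

# Hypothetical config covering every field get_dataloaders reads
cfg = OmegaConf.create({
    'dataset_type': 'state',      # 'state' -> StateDataset, anything else -> VisualDataset
    'data_dir': '/path/to/data',  # placeholder path
    'frame_interval': 8,          # only used by VisualDataset
    'video_type': 'color',        # only used by VisualDataset
    'train_dset_split': 0.8,
    'seed': 42,
    'distributed': False,         # single process, so no DistributedSampler
    'batch_size': 64,
    'num_workers': 4,
})

train_loader, test_loader, dataset = get_dataloaders(cfg)
print(len(dataset), len(train_loader), len(test_loader))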
@@ -0,0 +1,54 @@
import hydra
from torch.nn.parallel import DistributedDataParallel as DDP


# Scripts to initialize different agents - used in training scripts
def init_pli(cfg, device, rank):
    # Initialize the model
    model = hydra.utils.instantiate(cfg.model,
                                    input_dim=cfg.pos_dim*2,  # For dog and box
                                    action_dim=cfg.action_dim,
                                    hidden_dim=cfg.hidden_dim).to(device)
    model = DDP(model, device_ids=[rank], output_device=rank, broadcast_buffers=False)

    # Initialize the optimizer
    # parameters = list(encoder.parameters()) + list(trans.parameters())
    optimizer = hydra.utils.instantiate(cfg.optimizer,
                                        params=model.parameters(),
                                        lr=cfg.lr,
                                        weight_decay=cfg.weight_decay)

    # Initialize the total agent
    agent = hydra.utils.instantiate(cfg.agent,
                                    model=model,
                                    optimizer=optimizer)

    agent.to(device)

    return agent


def init_cpn(cfg, device, rank):
    # Initialize the encoder and the trans
    encoder = hydra.utils.instantiate(cfg.encoder).to(device)
    trans = hydra.utils.instantiate(cfg.trans,
                                    z_dim=cfg.z_dim,
                                    action_dim=cfg.action_dim).to(device)
    encoder = DDP(encoder, device_ids=[rank], output_device=rank, broadcast_buffers=False)  # To fix the inplace error https://github.com/pytorch/pytorch/issues/22095
    trans = DDP(trans, device_ids=[rank], output_device=rank, broadcast_buffers=False)

    # Initialize the optimizer
    parameters = list(encoder.parameters()) + list(trans.parameters())
    optimizer = hydra.utils.instantiate(cfg.optimizer,
                                        params=parameters,
                                        lr=cfg.lr,
                                        weight_decay=cfg.weight_decay)

    # Initialize the total agent
    agent = hydra.utils.instantiate(cfg.agent,
                                    encoder=encoder,
                                    trans=trans,
                                    optimizer=optimizer)
    agent.to(device)

    return agent
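
And a hedged sketch of how one of these initializers would typically be invoked from a DDP training entry point. The process-group bootstrap, module path, and config layout below are illustrative assumptions; the commit itself only adds the initializers.

import os
import torch
import torch.distributed as dist
from omegaconf import DictConfig

# Module path is an assumption; the diff does not show the new file's name.
from contrastive_learning.agents.initializers import init_cpn


def train(rank: int, world_size: int, cfg: DictConfig) -> None:
    # Standard single-node DDP bootstrap; address and port are placeholders
    os.environ.setdefault('MASTER_ADDR', 'localhost')
    os.environ.setdefault('MASTER_PORT', '29500')
    dist.init_process_group('nccl', rank=rank, world_size=world_size)
    torch.cuda.set_device(rank)
    device = torch.device(f'cuda:{rank}')

    # cfg.encoder, cfg.trans, cfg.optimizer, cfg.agent are assumed to be Hydra _target_ configs
    agent = init_cpn(cfg, device, rank)

    # ... training loop over the dataloaders would go here ...

    dist.destroy_process_group()

# Spawn one process per GPU, e.g.:
#   torch.multiprocessing.spawn(train, args=(world_size, cfg), nprocs=world_size)
# where cfg would normally be built by @hydra.main from the project's config files.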