Skip to content

Commit

Permalink
add augmentation for diffusion
Browse files Browse the repository at this point in the history
  • Loading branch information
tonyzhaozh committed Dec 15, 2023
1 parent 47270bd commit 3794f96
Show file tree
Hide file tree
Showing 2 changed files with 58 additions and 3 deletions.
32 changes: 32 additions & 0 deletions commands.txt
Original file line number Diff line number Diff line change
Expand Up @@ -169,6 +169,38 @@ CUDA_VISIBLE_DEVICES=1 python3 imitate_episodes.py \
--batch_size 32 --lr 1e-4 --seed 0 \
--num_steps 1000000 --eval_every 1000000 --validate_every 5000 --save_every 5000

## Cotrain
conda activate mobile
export MUJOCO_GL=egl
cd /home/tonyzhao/Research/act-plus-plus
CUDA_VISIBLE_DEVICES=1 python3 imitate_episodes.py \
--task_name aloha_mobile_wipe_wine_cotrain \
--ckpt_dir /scr/tonyzhao/train_logs/wipe_wine_cotrain_diffusion_seed0 \
--policy_class Diffusion --chunk_size 32 \
--batch_size 32 --lr 1e-4 --seed 0 \
--num_steps 1000000 --eval_every 1000000 --validate_every 5000 --save_every 5000

# Train without cotraining again, this time with augmentations
conda activate mobile
export MUJOCO_GL=egl
cd /home/tonyzhao/Research/act-plus-plus
CUDA_VISIBLE_DEVICES=0 python3 imitate_episodes.py \
--task_name aloha_mobile_wipe_wine \
--ckpt_dir /scr/tonyzhao/train_logs/wipe_wine_diffusion_augmentation_seed0 \
--policy_class Diffusion --chunk_size 32 \
--batch_size 32 --lr 1e-4 --seed 0 \
--num_steps 1000000 --eval_every 1000000 --validate_every 5000 --save_every 5000

## Cotrain with augmentations
conda activate mobile
export MUJOCO_GL=egl
cd /home/tonyzhao/Research/act-plus-plus
CUDA_VISIBLE_DEVICES=1 python3 imitate_episodes.py \
--task_name aloha_mobile_wipe_wine_cotrain \
--ckpt_dir /scr/tonyzhao/train_logs/wipe_wine_cotrain_diffusion_augmentation_seed0 \
--policy_class Diffusion --chunk_size 32 \
--batch_size 32 --lr 1e-4 --seed 0 \
--num_steps 1000000 --eval_every 1000000 --validate_every 5000 --save_every 5000



Expand Down
29 changes: 26 additions & 3 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
import pickle
import fnmatch
import cv2
from time import time
from torch.utils.data import TensorDataset, DataLoader
import torchvision.transforms as transforms

import IPython
e = IPython.embed
Expand All @@ -25,7 +27,12 @@ def __init__(self, dataset_path_list, camera_names, norm_stats, episode_ids, epi
self.cumulative_len = np.cumsum(self.episode_len)
self.max_episode_len = max(episode_len)
self.policy_class = policy_class
self.__getitem__(0) # initialize self.is_sim
if self.policy_class == 'Diffusion':
self.augment_images = True
else:
self.augment_images = False
self.transformations = None
self.__getitem__(0) # initialize self.is_sim and self.transformations
self.is_sim = False

# def __len__(self):
Expand Down Expand Up @@ -102,6 +109,22 @@ def __getitem__(self, index):
# channel last -> channel first: (k, h, w, c) -> (k, c, h, w)
image_data = torch.einsum('k h w c -> k c h w', image_data)

# augmentation
if self.transformations is None:
print('Initializing transformations')
original_size = image_data.shape[2:]
ratio = 0.95
self.transformations = [
transforms.RandomCrop(size=[int(original_size[0] * ratio), int(original_size[1] * ratio)]),
transforms.Resize(original_size, antialias=True),
transforms.RandomRotation(degrees=[-5.0, 5.0], expand=False),
transforms.ColorJitter(brightness=0.3, contrast=0.4, saturation=0.5) #, hue=0.08)
]

if self.augment_images:
for transform in self.transformations:
image_data = transform(image_data)

# normalize image and change dtype to float
image_data = image_data / 255.0

Expand Down Expand Up @@ -238,8 +261,8 @@ def load_data(dataset_dir_l, name_filter, camera_names, batch_size_train, batch_
# construct dataset and dataloader
train_dataset = EpisodicDataset(dataset_path_list, camera_names, norm_stats, train_episode_ids, train_episode_len, chunk_size, policy_class)
val_dataset = EpisodicDataset(dataset_path_list, camera_names, norm_stats, val_episode_ids, val_episode_len, chunk_size, policy_class)
train_dataloader = DataLoader(train_dataset, batch_sampler=batch_sampler_train, pin_memory=True, num_workers=2, prefetch_factor=2)
val_dataloader = DataLoader(val_dataset, batch_sampler=batch_sampler_val, pin_memory=True, num_workers=2, prefetch_factor=2)
train_dataloader = DataLoader(train_dataset, batch_sampler=batch_sampler_train, pin_memory=True, num_workers=16, prefetch_factor=2)
val_dataloader = DataLoader(val_dataset, batch_sampler=batch_sampler_val, pin_memory=True, num_workers=8, prefetch_factor=2)

return train_dataloader, val_dataloader, norm_stats, train_dataset.is_sim

Expand Down

0 comments on commit 3794f96

Please sign in to comment.