Skip to content

Commit

Permalink
VINN works for sim transfer cube
Browse files Browse the repository at this point in the history
  • Loading branch information
tonyzhaozh committed Dec 11, 2023
1 parent c460687 commit 91f1bbf
Show file tree
Hide file tree
Showing 6 changed files with 621 additions and 0 deletions.
1 change: 1 addition & 0 deletions byol_pytorch
Submodule byol_pytorch added at 25e5b3
50 changes: 50 additions & 0 deletions commands.txt
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,56 @@ CUDA_VISIBLE_DEVICES=1 python3 imitate_episodes.py \
--batch_size 32 --lr 1e-4 --seed 0 \
--num_steps 100000 --eval_every 2000 --validate_every 2000 --save_every 2000

# Dec 10

######################## more diffusion ########################
conda activate mobile
export MUJOCO_GL=egl
cd /home/tonyzhao/Research/act-plus-plus
CUDA_VISIBLE_DEVICES=0 python3 imitate_episodes.py \
--task_name sim_transfer_cube_scripted \
--ckpt_dir /scr/tonyzhao/train_logs/cube_scripted_diffusion_sweep_3_chunk64 \
--policy_class Diffusion --chunk_size 64 \
--batch_size 32 --lr 1e-4 --seed 0 \
--num_steps 200000 --eval_every 4000 --validate_every 4000 --save_every 4000




######################## VINN ########################


conda activate mobile
cd /home/tonyzhao/Research/act-plus-plus/byol_pytorch/examples/lightning
CUDA_VISIBLE_DEVICES=1 python3 train.py --dataset_dir /scr/tonyzhao/datasets/sim_transfer_cube_scripted --cam_name top --seed 0

conda activate mobile
cd /home/tonyzhao/Research/act-plus-plus/byol_pytorch/examples/lightning
CUDA_VISIBLE_DEVICES=0 python3 train.py --dataset_dir /scr/tonyzhao/datasets/sim_transfer_cube_scripted --cam_name left_wrist --seed 0

conda activate mobile
cd /home/tonyzhao/Research/act-plus-plus/byol_pytorch/examples/lightning
CUDA_VISIBLE_DEVICES=1 python3 train.py --dataset_dir /scr/tonyzhao/datasets/sim_transfer_cube_scripted --cam_name right_wrist --seed 0

conda activate mobile
cd /home/tonyzhao/Research/act-plus-plus
TASK_NAME=sim_transfer_cube_scripted
python3 vinn_cache_feature.py --ckpt_path /home/tonyzhao/Research/act-plus-plus/byol_pytorch/examples/lightning/byol-${TASK_NAME}-DUMMY-seed-0.pt

TASK_NAME=sim_transfer_cube_scripted
python3 vinn_select_k.py \
--dataset_dir /scr/tonyzhao/datasets/sim_transfer_cube_scripted \
--ckpt_dir /scr/tonyzhao/train_logs/VINN-eval-seed-0-test

python3 vinn_eval.py \
--dataset_dir /scr/tonyzhao/datasets/sim_transfer_cube_scripted \
--model_dir /home/tonyzhao/Research/act-plus-plus/byol_pytorch/examples/lightning/byol-${TASK_NAME}-DUMMY-seed-0.pt \
--ckpt_dir /scr/tonyzhao/train_logs/VINN-eval-seed-0-test \
--task_name $TASK_NAME

## TODO
make sure env is consistent
tune a bit more


---------------------------------------------------------------------------------------
Expand Down
1 change: 1 addition & 0 deletions utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,7 @@ def find_all_hdf5(dataset_dir, skip_mirrored_data):
hdf5_files = []
for root, dirs, files in os.walk(dataset_dir):
for filename in fnmatch.filter(files, '*.hdf5'):
if 'features' in filename: continue
if skip_mirrored_data and 'mirror' in filename:
continue
hdf5_files.append(os.path.join(root, filename))
Expand Down
132 changes: 132 additions & 0 deletions vinn_cache_feature.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
import torch
import argparse
import pathlib
from torch import nn
import torchvision
import os
import time
import h5py
import h5py
from torchvision import models, transforms
from PIL import Image

import IPython
e = IPython.embed


def chunks(lst, n):
"""Yield successive n-sized chunks from lst."""
for i in range(0, len(lst), n):
yield lst[i:i + n]

def expand_greyscale(t):
return t.expand(3, -1, -1)


def main(args):
#################################################
batch_size = 256
#################################################

ckpt_path = args.ckpt_path
ckpt_name = pathlib.PurePath(ckpt_path).name
dataset_name = ckpt_name.split('-')[1]
repr_type = ckpt_name.split('-')[0]
seed = int(ckpt_name.split('-')[-1][:-3])

dataset_dir = f'/scr/tonyzhao/datasets/{dataset_name}'

episode_idxs = [int(name.split('_')[1].split('.')[0]) for name in os.listdir(dataset_dir) if ('.hdf5' in name) and ('features' not in name)]
episode_idxs.sort()
assert len(episode_idxs) == episode_idxs[-1] + 1 # no holes
num_episodes = len(episode_idxs)

feature_extractors = {}

for episode_idx in range(num_episodes):

# load all images
dataset_path = os.path.join(dataset_dir, f'episode_{episode_idx}.hdf5')
with h5py.File(dataset_path, 'r') as root:
image_dict = {}
camera_names = list(root[f'/observations/images/'].keys())
print(f'Camera names: {camera_names}')
for cam_name in camera_names:
image_dict[cam_name] = root[f'/observations/images/{cam_name}'][:]

# load pretrain nets after cam names are known
if not feature_extractors:
for cam_name in camera_names:
resnet = torchvision.models.resnet18(pretrained=True)
loading_status = resnet.load_state_dict(torch.load(ckpt_path.replace('DUMMY', cam_name)))
print(cam_name, loading_status)
resnet = nn.Sequential(*list(resnet.children())[:-1])
resnet = resnet.cuda()
resnet.eval()
feature_extractors[cam_name] = resnet

# inference with resnet
feature_dict = {}
for cam_name, images in image_dict.items():
# Preprocess images
image_size = 120 # TODO NOTICE: reduced resolution
transform = transforms.Compose([
transforms.Resize(image_size), # will scale the image
transforms.CenterCrop(image_size),
transforms.ToTensor(),
transforms.Lambda(expand_greyscale),
transforms.Normalize(
mean=torch.tensor([0.485, 0.456, 0.406]),
std=torch.tensor([0.229, 0.224, 0.225])),
])
processed_images = []
for image in images:
image = Image.fromarray(image)
image = transform(image)
processed_images.append(image)
processed_images = torch.stack(processed_images).cuda()

# query the model
all_features = []
with torch.inference_mode():
for batch in chunks(processed_images, batch_size):
features = feature_extractors[cam_name](batch)
features = features.squeeze(axis=3).squeeze(axis=2)
all_features.append(features)
all_features = torch.cat(all_features, axis=0)
max_timesteps = all_features.shape[0]
feature_dict[cam_name] = all_features

# TODO START diagnostics
# first_image = images[0]
# first_processed_image = processed_images[0].cpu().numpy()
# first_feature = all_features[0].cpu().numpy()
# import numpy as np
# np.save('first_image.npy', first_image)
# np.save('first_processed_image.npy', first_processed_image)
# np.save('first_feature.npy', first_feature)
# torch.save(resnet.state_dict(), 'rn.ckpt')
# e()
# exit()
# TODO END diagnostics


# save
dataset_path = os.path.join(dataset_dir, f'{repr_type}_features_seed{seed}_episode_{episode_idx}.hdf5')
print(dataset_path)
# HDF5
t0 = time.time()
with h5py.File(dataset_path, 'w', rdcc_nbytes=1024 ** 2 * 2) as root:
features = root.create_group('features')
for cam_name, array in feature_dict.items():
cam_feature = features.create_dataset(cam_name, (max_timesteps, 512))
features[cam_name][...] = array.cpu().numpy()
print(f'Saving: {time.time() - t0:.1f} secs\n')


if __name__ == '__main__':
parser = argparse.ArgumentParser(description='cache features')
parser.add_argument('--ckpt_path', type=str, required=True, help='ckpt_path')
args = parser.parse_args()

main(args)
Loading

0 comments on commit 91f1bbf

Please sign in to comment.