Skip to content

Commit

Permalink
refactor constants all tasks! get_episode_len.py
Browse files Browse the repository at this point in the history
  • Loading branch information
MarkFzp committed Dec 12, 2023
1 parent 196e837 commit a0bdc7b
Show file tree
Hide file tree
Showing 2 changed files with 186 additions and 53 deletions.
205 changes: 152 additions & 53 deletions aloha_scripts/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,45 +4,36 @@

DATA_DIR = os.path.expanduser('~/data')
TASK_CONFIGS = {
'aloha_wear_shoe':{
'dataset_dir': DATA_DIR + '/aloha_wear_shoe',
'num_episodes': 50,
'aloha_mobile_dummy':{
'dataset_dir': DATA_DIR + '/aloha_mobile_dummy',
'episode_len': 1000,
'camera_names': ['cam_high', 'cam_low', 'cam_left_wrist', 'cam_right_wrist']
},
'aloha_mobile_grasp_pen':{
'dataset_dir': DATA_DIR + '/aloha_mobile_grasp_pen',
'num_episodes': 50,
'episode_len': 500,
'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist']
},
'aloha_mobile_grasp_pen_diverse':{
'dataset_dir': DATA_DIR + '/aloha_mobile_grasp_pen_diverse',
'num_episodes': 50,
'episode_len': 500,
'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist']
},
'aloha_mobile_grasp_pen_all':{
'dataset_dir': DATA_DIR + '/aloha_mobile_grasp_pen_all',
'num_episodes': 100,
'episode_len': 500,
'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist']
},
'aloha_mobile_object_to_cabinet':{
'dataset_dir': DATA_DIR + '/aloha_mobile_object_to_cabinet',
'num_episodes': 50,
'episode_len': 1500,

# wash_pan
'aloha_mobile_wash_pan':{
'dataset_dir': DATA_DIR + '/aloha_mobile_wash_pan',
'episode_len': 1100,
'train_ratio': 0.9,
'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist']
},
'aloha_mobile_glass_to_cabinet':{
'dataset_dir': DATA_DIR + '/aloha_mobile_glass_to_cabinet',
'num_episodes': 50,
'episode_len': 1500,
'aloha_mobile_wash_pan_cotrain':{
'dataset_dir': [
DATA_DIR + '/aloha_mobile_wash_pan',
DATA_DIR + '/aloha_compressed_dataset',
], # only the first dataset_dir is used for val
'stats_dir': [
DATA_DIR + '/aloha_mobile_wash_pan',
],
'sample_weights': [5, 5],
'train_ratio': 0.9, # ratio of train data from the first dataset_dir
'episode_len': 1100,
'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist']
},

# wipe_wine
'aloha_mobile_wipe_wine':{
'dataset_dir': DATA_DIR + '/aloha_mobile_wipe_wine',
'num_episodes': 50,
'episode_len': 1300,
'train_ratio': 0.9,
'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist']
Expand All @@ -57,47 +48,155 @@
],
'sample_weights': [5, 5],
'train_ratio': 0.9, # ratio of train data from the first dataset_dir
'episode_len': 1100,
'episode_len': 1300,
'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist']
},
'aloha_mobile_wash_pan':{
'dataset_dir': DATA_DIR + '/aloha_mobile_wash_pan',
'num_episodes': 50,
'episode_len': 1100,

# cabinet
'aloha_mobile_cabinet':{
'dataset_dir': [
DATA_DIR + '/aloha_mobile_cabinet',
DATA_DIR + '/aloha_mobile_cabinet_handles', # 200
DATA_DIR + '/aloha_mobile_cabinet_grasp_pots', # 200
], # only the first dataset_dir is used for val
'stats_dir': [
DATA_DIR + '/aloha_mobile_cabinet',
],
'sample_weights': [6, 1, 1],
'train_ratio': 0.99, # ratio of train data from the first dataset_dir
'episode_len': 1500,
'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist']
},
'aloha_mobile_cabinet_cotrain':{
'dataset_dir': [
DATA_DIR + '/aloha_mobile_cabinet',
DATA_DIR + '/aloha_mobile_cabinet_handles',
DATA_DIR + '/aloha_mobile_cabinet_grasp_pots',
DATA_DIR + '/aloha_compressed_dataset',
], # only the first dataset_dir is used for val
'stats_dir': [
DATA_DIR + '/aloha_mobile_cabinet',
],
'sample_weights': [6, 1, 1, 2],
'train_ratio': 0.99, # ratio of train data from the first dataset_dir
'episode_len': 1500,
'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist']
},

# elevator
'aloha_mobile_elevator':{
'dataset_dir': DATA_DIR + '/aloha_mobile_elevator',
'train_ratio': 0.99,
'episode_len': 8500,
'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist']
},
'aloha_mobile_elevator_truncated':{
'dataset_dir': [
DATA_DIR + '/aloha_mobile_elevator_truncated',
DATA_DIR + '/aloha_mobile_elevator_2', # 1200
DATA_DIR + '/aloha_mobile_elevator_button', # 800
], # only the first dataset_dir is used for val
'stats_dir': [
DATA_DIR + '/aloha_mobile_elevator_truncated',
DATA_DIR + '/aloha_mobile_elevator_2',
],
'sample_weights': [3, 3, 2],
'train_ratio': 0.99, # ratio of train data from the first dataset_dir
'episode_len': 2250,
'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist']
},
'aloha_mobile_elevator_truncated_cotrain':{
'dataset_dir': [
DATA_DIR + '/aloha_mobile_elevator_truncated',
DATA_DIR + '/aloha_mobile_elevator_2',
DATA_DIR + '/aloha_mobile_elevator_button',
DATA_DIR + '/aloha_compressed_dataset',
], # only the first dataset_dir is used for val
'stats_dir': [
DATA_DIR + '/aloha_mobile_elevator_truncated',
DATA_DIR + '/aloha_mobile_elevator_2',
],
'sample_weights': [3, 3, 2, 1],
'train_ratio': 0.99, # ratio of train data from the first dataset_dir
'episode_len': 2250,
'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist']
},

# high_five
'aloha_mobile_high_five':{
'dataset_dir': DATA_DIR + '/aloha_mobile_high_five',
'train_ratio': 0.9,
'episode_len': 2000,
'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist']
},
'aloha_mobile_wash_pan_cotrain':{
'aloha_mobile_high_five_cotrain':{
'dataset_dir': [
DATA_DIR + '/aloha_mobile_wash_pan',
DATA_DIR + '/aloha_mobile_high_five',
DATA_DIR + '/aloha_compressed_dataset',
], # only the first dataset_dir is used for val
'stats_dir': [
DATA_DIR + '/aloha_mobile_wash_pan',
DATA_DIR + '/aloha_mobile_high_five',
],
'sample_weights': [5, 5],
'sample_weights': [7.5, 2.5],
'train_ratio': 0.9, # ratio of train data from the first dataset_dir
'episode_len': 1100,
'episode_len': 2000,
'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist']
},
'aloha_mobile_dummy':{
'dataset_dir': DATA_DIR + '/aloha_mobile_dummy',
'num_episodes': 50,
'episode_len': 1000,

# chair
'aloha_mobile_chair':{
'dataset_dir': DATA_DIR + '/aloha_mobile_chair',
'train_ratio': 0.95,
'episode_len': 2400,
'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist']
},
'aloha_mobile_fork':{
'dataset_dir': DATA_DIR + '/aloha_mobile_fork',
'num_episodes': 50,
'episode_len': 400,
'aloha_mobile_chair_truncated':{
'dataset_dir': DATA_DIR + '/aloha_mobile_chair_truncated',
'train_ratio': 0.95,
'episode_len': 2000,
'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist']
},
'aloha_mobile_elevator':{
'dataset_dir': DATA_DIR + '/aloha_mobile_elevator',
'num_episodes': 50,
'episode_len': 8500,
'aloha_mobile_chair_truncated_cotrain':{
'dataset_dir': [
DATA_DIR + '/aloha_mobile_chair_truncated',
DATA_DIR + '/aloha_compressed_dataset',
], # only the first dataset_dir is used for val
'stats_dir': [
DATA_DIR + '/aloha_mobile_chair_truncated',
],
'sample_weights': [5, 5],
'train_ratio': 0.95, # ratio of train data from the first dataset_dir
'episode_len': 2000,
'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist']
},

# shrimp
'aloha_mobile_shrimp':{
'dataset_dir': DATA_DIR + '/aloha_mobile_shrimp',
'train_ratio': 0.99,
'episode_len': 4500,
'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist']
},
'aloha_mobile_shrimp_truncated':{
'dataset_dir': DATA_DIR + '/aloha_mobile_shrimp_truncated',
'train_ratio': 0.99, # ratio of train data from the first dataset_dir
'episode_len': 3750,
'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist']
},
'aloha_mobile_shrimp_2_cotrain':{
'dataset_dir': [
DATA_DIR + '/aloha_mobile_shrimp_2',
DATA_DIR + '/aloha_mobile_shrimp_before_spatula_down', # 2200
DATA_DIR + '/aloha_compressed_dataset',
], # only the first dataset_dir is used for val
'stats_dir': [
DATA_DIR + '/aloha_mobile_shrimp_2',
],
'sample_weights': [5, 3, 2],
'train_ratio': 0.99, # ratio of train data from the first dataset_dir
'episode_len': 4500,
'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist']
}
},
}

### ALOHA fixed constants
Expand Down
34 changes: 34 additions & 0 deletions aloha_scripts/get_episode_len.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
import os
import numpy as np
import cv2
import h5py
import argparse

import matplotlib.pyplot as plt
from constants import DT

import IPython
e = IPython.embed

def main(args):
dataset_dir = args['dataset_dir']
episode_idx = args['episode_idx']
dataset_name = f'episode_{episode_idx}'

dataset_path = os.path.join(dataset_dir, dataset_name + '.hdf5')
if not os.path.isfile(dataset_path):
print(f'Dataset does not exist at \n{dataset_path}\n')
exit()

with h5py.File(dataset_path, 'r') as root:
is_sim = root.attrs['sim']
compressed = root.attrs.get('compress', False)
qpos = root['/observations/qpos'][()]
print(f'dataset_name: {dataset_name}, episode {episode_idx}, len: {len(qpos)}')

if __name__ == '__main__':
parser = argparse.ArgumentParser()
parser.add_argument('--dataset_dir', action='store', type=str, help='Dataset dir.', required=True)
parser.add_argument('--episode_idx', action='store', type=int, help='Episode index.', required=False)
parser.add_argument('--ismirror', action='store_true')
main(vars(parser.parse_args()))

0 comments on commit a0bdc7b

Please sign in to comment.