Skip to content

Commit

Permalink
record_episode compress, add aloha_mobile_wash_pan task to constants
Browse files Browse the repository at this point in the history
  • Loading branch information
MarkFzp committed Nov 5, 2023
1 parent 8f4060e commit 3533bb8
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 5 deletions.
6 changes: 6 additions & 0 deletions aloha_scripts/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,12 @@
'episode_len': 1300,
'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist']
},
'aloha_mobile_wash_pan':{
'dataset_dir': DATA_DIR + '/aloha_mobile_wash_pan',
'num_episodes': 50,
'episode_len': 1200,
'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist']
},
}

### ALOHA fixed constants
Expand Down
55 changes: 50 additions & 5 deletions aloha_scripts/record_episodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@
import time
import h5py
import argparse
import h5py_cache
import numpy as np
from tqdm import tqdm
import cv2

from constants import DT, START_ARM_POSE, TASK_CONFIGS
from constants import MASTER_GRIPPER_JOINT_MID, PUPPET_GRIPPER_JOINT_CLOSE, PUPPET_GRIPPER_JOINT_OPEN
Expand Down Expand Up @@ -104,6 +106,8 @@ def capture_one_episode(dt, max_timesteps, camera_names, dataset_dir, dataset_na
torque_on(master_bot_left)
torque_on(master_bot_right)
# Open puppet grippers
env.puppet_bot_left.dxl.robot_set_operating_modes("single", "gripper", "position")
env.puppet_bot_right.dxl.robot_set_operating_modes("single", "gripper", "position")
move_grippers([env.puppet_bot_left, env.puppet_bot_right], [PUPPET_GRIPPER_JOINT_OPEN] * 2, move_time=0.5)

freq_mean = print_dt_diagnosis(actual_dt_history)
Expand Down Expand Up @@ -147,17 +151,53 @@ def capture_one_episode(dt, max_timesteps, camera_names, dataset_dir, dataset_na
for cam_name in camera_names:
data_dict[f'/observations/images/{cam_name}'].append(ts.observation['images'][cam_name])

COMPRESS = True

if COMPRESS:
# JPEG compression
t0 = time.time()
encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 50] # tried as low as 20, seems fine
compressed_len = []
for cam_name in camera_names:
image_list = data_dict[f'/observations/images/{cam_name}']
compressed_list = []
compressed_len.append([])
for image in image_list:
result, encoded_image = cv2.imencode('.jpg', image, encode_param) # 0.02 sec # cv2.imdecode(encoded_image, 1)
compressed_list.append(encoded_image)
compressed_len[-1].append(len(encoded_image))
data_dict[f'/observations/images/{cam_name}'] = compressed_list
print(f'compression: {time.time() - t0:.2f}s')

# pad so it has same length
t0 = time.time()
compressed_len = np.array(compressed_len)
padded_size = compressed_len.max()
for cam_name in camera_names:
compressed_image_list = data_dict[f'/observations/images/{cam_name}']
padded_compressed_image_list = []
for compressed_image in compressed_image_list:
padded_compressed_image = np.zeros(padded_size, dtype='uint8')
image_len = len(compressed_image)
padded_compressed_image[:image_len] = compressed_image
padded_compressed_image_list.append(padded_compressed_image)
data_dict[f'/observations/images/{cam_name}'] = padded_compressed_image_list
print(f'padding: {time.time() - t0:.2f}s')

# HDF5
t0 = time.time()
with h5py.File(dataset_path + '.hdf5', 'w', rdcc_nbytes=1024**2*2) as root:
with h5py_cache.File(dataset_path + '.hdf5', 'w', chunk_cache_mem_size=1024**2*2) as root:
root.attrs['sim'] = False
root.attrs['compress'] = COMPRESS
obs = root.create_group('observations')
image = obs.create_group('images')
for cam_name in camera_names:
_ = image.create_dataset(cam_name, (max_timesteps, 480, 640, 3), dtype='uint8',
chunks=(1, 480, 640, 3), )
# compression='gzip',compression_opts=2,)
# compression=32001, compression_opts=(0, 0, 0, 0, 9, 1, 1), shuffle=False)
if COMPRESS:
_ = image.create_dataset(cam_name, (max_timesteps, padded_size), dtype='uint8',
chunks=(1, padded_size), )
else:
_ = image.create_dataset(cam_name, (max_timesteps, 480, 640, 3), dtype='uint8',
chunks=(1, 480, 640, 3), )
_ = obs.create_dataset('qpos', (max_timesteps, 14))
_ = obs.create_dataset('qvel', (max_timesteps, 14))
_ = obs.create_dataset('effort', (max_timesteps, 14))
Expand All @@ -166,6 +206,11 @@ def capture_one_episode(dt, max_timesteps, camera_names, dataset_dir, dataset_na

for name, array in data_dict.items():
root[name][...] = array

if COMPRESS:
_ = root.create_dataset('compress_len', (len(camera_names), max_timesteps))
root['/compress_len'][...] = compressed_len

print(f'Saving: {time.time() - t0:.1f} secs')

return True
Expand Down

0 comments on commit 3533bb8

Please sign in to comment.