Skip to content

Commit

Permalink
Merge branch 'main' of github.com:MarkFzp/mobile-aloha into main
Browse files Browse the repository at this point in the history
  • Loading branch information
MarkFzp committed Nov 1, 2023
2 parents cdd65c4 + 9a325e6 commit 24a2c83
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 7 deletions.
6 changes: 6 additions & 0 deletions aloha_scripts/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,12 @@
'episode_len': 500,
'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist']
},
'aloha_mobile_grasp_pen_all':{
'dataset_dir': DATA_DIR + '/aloha_mobile_grasp_pen_all',
'num_episodes': 100,
'episode_len': 500,
'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist']
},
'aloha_mobile_object_to_cabinet':{
'dataset_dir': DATA_DIR + '/aloha_mobile_object_to_cabinet',
'num_episodes': 50,
Expand Down
1 change: 1 addition & 0 deletions aloha_scripts/record_episodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def capture_one_episode(dt, max_timesteps, camera_names, dataset_dir, dataset_na
- qvel (14,) 'float64'
action (14,) 'float64'
base_action (2,) 'float64'
"""

data_dict = {
Expand Down
52 changes: 45 additions & 7 deletions aloha_scripts/visualize_episodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"]
STATE_NAMES = JOINT_NAMES + ["gripper"]
BASE_STATE_NAMES = ["linear_vel", "angular_vel"]

def load_hdf5(dataset_dir, dataset_name):
dataset_path = os.path.join(dataset_dir, dataset_name + '.hdf5')
Expand All @@ -23,24 +24,33 @@ def load_hdf5(dataset_dir, dataset_name):
is_sim = root.attrs['sim']
qpos = root['/observations/qpos'][()]
qvel = root['/observations/qvel'][()]
effort = root['/observations/effort'][()]
if 'effort' in root.keys():
effort = root['/observations/effort'][()]
else:
effort = None
action = root['/action'][()]
base_action = root['/base_action'][()]
image_dict = dict()
for cam_name in root[f'/observations/images/'].keys():
image_dict[cam_name] = root[f'/observations/images/{cam_name}'][()]

return qpos, qvel, effort, action, image_dict
return qpos, qvel, effort, action, base_action, image_dict

def main(args):
dataset_dir = args['dataset_dir']
episode_idx = args['episode_idx']
dataset_name = f'episode_{episode_idx}'
ismirror = args['ismirror']
if ismirror:
dataset_name = f'mirror_episode_{episode_idx}'
else:
dataset_name = f'episode_{episode_idx}'

qpos, qvel, effort, action, image_dict = load_hdf5(dataset_dir, dataset_name)
qpos, qvel, effort, action, base_action, image_dict = load_hdf5(dataset_dir, dataset_name)
save_videos(image_dict, DT, video_path=os.path.join(dataset_dir, dataset_name + '_video.mp4'))
visualize_joints(qpos, action, plot_path=os.path.join(dataset_dir, dataset_name + '_qpos.png'))
visualize_single(effort, 'effort', plot_path=os.path.join(dataset_dir, dataset_name + '_effort.png'))
visualize_single(action - qpos, 'tracking_error', plot_path=os.path.join(dataset_dir, dataset_name + '_error.png'))
# visualize_single(effort, 'effort', plot_path=os.path.join(dataset_dir, dataset_name + '_effort.png'))
# visualize_single(action - qpos, 'tracking_error', plot_path=os.path.join(dataset_dir, dataset_name + '_error.png'))
visualize_base(base_action, plot_path=os.path.join(dataset_dir, dataset_name + '_base_action.png'))
# visualize_timestamp(t_list, dataset_path) # TODO addn timestamp back


Expand Down Expand Up @@ -90,7 +100,7 @@ def visualize_joints(qpos_list, command_list, plot_path=None, ylim=None, label_o
num_ts, num_dim = qpos.shape
h, w = 2, num_dim
num_figs = num_dim
fig, axs = plt.subplots(num_figs, 1, figsize=(w, h * num_figs))
fig, axs = plt.subplots(num_figs, 1, figsize=(8, 2 * num_dim))

# plot joint state
all_names = [name + '_left' for name in STATE_NAMES] + [name + '_right' for name in STATE_NAMES]
Expand Down Expand Up @@ -141,6 +151,33 @@ def visualize_single(efforts_list, label, plot_path=None, ylim=None, label_overw
print(f'Saved effort plot to: {plot_path}')
plt.close()

def visualize_base(readings, plot_path=None):
readings = np.array(readings) # ts, dim
num_ts, num_dim = readings.shape
num_figs = num_dim
fig, axs = plt.subplots(num_figs, 1, figsize=(8, 2 * num_dim))

# plot joint state
all_names = BASE_STATE_NAMES
for dim_idx in range(num_dim):
ax = axs[dim_idx]
ax.plot(readings[:, dim_idx], label='raw')
ax.plot(np.convolve(readings[:, dim_idx], np.ones(20)/20, mode='same'), label='smoothed_20')
ax.plot(np.convolve(readings[:, dim_idx], np.ones(10)/10, mode='same'), label='smoothed_10')
ax.plot(np.convolve(readings[:, dim_idx], np.ones(5)/5, mode='same'), label='smoothed_5')
ax.set_title(f'Joint {dim_idx}: {all_names[dim_idx]}')
ax.legend()

# if ylim:
# for dim_idx in range(num_dim):
# ax = axs[dim_idx]
# ax.set_ylim(ylim)

plt.tight_layout()
plt.savefig(plot_path)
print(f'Saved effort plot to: {plot_path}')
plt.close()


def visualize_timestamp(t_list, dataset_path):
plot_path = dataset_path.replace('.pkl', '_timestamp.png')
Expand Down Expand Up @@ -173,4 +210,5 @@ def visualize_timestamp(t_list, dataset_path):
parser = argparse.ArgumentParser()
parser.add_argument('--dataset_dir', action='store', type=str, help='Dataset dir.', required=True)
parser.add_argument('--episode_idx', action='store', type=int, help='Episode index.', required=False)
parser.add_argument('--ismirror', action='store_true')
main(vars(parser.parse_args()))

0 comments on commit 24a2c83

Please sign in to comment.