Skip to content

Commit

Permalink
add tasks, visualize base action in visualize_episodes.py
Browse files Browse the repository at this point in the history
  • Loading branch information
MarkFzp committed Oct 28, 2023
1 parent 4832ecd commit a398c2a
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 5 deletions.
12 changes: 12 additions & 0 deletions aloha_scripts/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,18 @@
'episode_len': 500,
'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist']
},
'aloha_mobile_grasp_pen_all':{
'dataset_dir': DATA_DIR + '/aloha_mobile_grasp_pen_all',
'num_episodes': 100,
'episode_len': 500,
'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist']
},
'aloha_mobile_object_to_cabinet':{
'dataset_dir': DATA_DIR + '/aloha_mobile_object_to_cabinet',
'num_episodes': 50,
'episode_len': 1700,
'camera_names': ['cam_high', 'cam_left_wrist', 'cam_right_wrist']
},
}

### ALOHA fixed constants
Expand Down
1 change: 1 addition & 0 deletions aloha_scripts/record_episodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def capture_one_episode(dt, max_timesteps, camera_names, dataset_dir, dataset_na
- qvel (14,) 'float64'
action (14,) 'float64'
base_action (2,) 'float64'
"""

data_dict = {
Expand Down
40 changes: 35 additions & 5 deletions aloha_scripts/visualize_episodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"]
STATE_NAMES = JOINT_NAMES + ["gripper"]
BASE_STATE_NAMES = ["linear_vel", "angular_vel"]

def load_hdf5(dataset_dir, dataset_name):
dataset_path = os.path.join(dataset_dir, dataset_name + '.hdf5')
Expand All @@ -25,22 +26,24 @@ def load_hdf5(dataset_dir, dataset_name):
qvel = root['/observations/qvel'][()]
effort = root['/observations/effort'][()]
action = root['/action'][()]
base_action = root['/base_action'][()]
image_dict = dict()
for cam_name in root[f'/observations/images/'].keys():
image_dict[cam_name] = root[f'/observations/images/{cam_name}'][()]

return qpos, qvel, effort, action, image_dict
return qpos, qvel, effort, action, base_action, image_dict

def main(args):
dataset_dir = args['dataset_dir']
episode_idx = args['episode_idx']
dataset_name = f'episode_{episode_idx}'

qpos, qvel, effort, action, image_dict = load_hdf5(dataset_dir, dataset_name)
qpos, qvel, effort, action, base_action, image_dict = load_hdf5(dataset_dir, dataset_name)
save_videos(image_dict, DT, video_path=os.path.join(dataset_dir, dataset_name + '_video.mp4'))
visualize_joints(qpos, action, plot_path=os.path.join(dataset_dir, dataset_name + '_qpos.png'))
visualize_single(effort, 'effort', plot_path=os.path.join(dataset_dir, dataset_name + '_effort.png'))
visualize_single(action - qpos, 'tracking_error', plot_path=os.path.join(dataset_dir, dataset_name + '_error.png'))
# visualize_single(effort, 'effort', plot_path=os.path.join(dataset_dir, dataset_name + '_effort.png'))
# visualize_single(action - qpos, 'tracking_error', plot_path=os.path.join(dataset_dir, dataset_name + '_error.png'))
visualize_base(base_action, plot_path=os.path.join(dataset_dir, dataset_name + '_base_action.png'))
# visualize_timestamp(t_list, dataset_path) # TODO addn timestamp back


Expand Down Expand Up @@ -90,7 +93,7 @@ def visualize_joints(qpos_list, command_list, plot_path=None, ylim=None, label_o
num_ts, num_dim = qpos.shape
h, w = 2, num_dim
num_figs = num_dim
fig, axs = plt.subplots(num_figs, 1, figsize=(w, h * num_figs))
fig, axs = plt.subplots(num_figs, 1, figsize=(8, 2 * num_dim))

# plot joint state
all_names = [name + '_left' for name in STATE_NAMES] + [name + '_right' for name in STATE_NAMES]
Expand Down Expand Up @@ -141,6 +144,33 @@ def visualize_single(efforts_list, label, plot_path=None, ylim=None, label_overw
print(f'Saved effort plot to: {plot_path}')
plt.close()

def visualize_base(readings, plot_path=None):
readings = np.array(readings) # ts, dim
num_ts, num_dim = readings.shape
num_figs = num_dim
fig, axs = plt.subplots(num_figs, 1, figsize=(8, 2 * num_dim))

# plot joint state
all_names = BASE_STATE_NAMES
for dim_idx in range(num_dim):
ax = axs[dim_idx]
ax.plot(readings[:, dim_idx], label='raw')
ax.plot(np.convolve(readings[:, dim_idx], np.ones(20)/20, mode='same'), label='smoothed_20')
ax.plot(np.convolve(readings[:, dim_idx], np.ones(10)/10, mode='same'), label='smoothed_10')
ax.plot(np.convolve(readings[:, dim_idx], np.ones(5)/5, mode='same'), label='smoothed_5')
ax.set_title(f'Joint {dim_idx}: {all_names[dim_idx]}')
ax.legend()

# if ylim:
# for dim_idx in range(num_dim):
# ax = axs[dim_idx]
# ax.set_ylim(ylim)

plt.tight_layout()
plt.savefig(plot_path)
print(f'Saved effort plot to: {plot_path}')
plt.close()


def visualize_timestamp(t_list, dataset_path):
plot_path = dataset_path.replace('.pkl', '_timestamp.png')
Expand Down

0 comments on commit a398c2a

Please sign in to comment.