Skip to content

Commit

Permalink
log qpos during eval
Browse files Browse the repository at this point in the history
  • Loading branch information
MarkFzp committed Dec 12, 2023
1 parent 66e07ff commit 547a734
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 3 deletions.
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -141,3 +141,7 @@ dmypy.json

ckpts/
*log*

*.png
*.npy
play.ipynb
27 changes: 24 additions & 3 deletions imitate_episodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,13 @@
import IPython
e = IPython.embed

def get_auto_index(dataset_dir):
max_idx = 1000
for i in range(max_idx+1):
if not os.path.isfile(os.path.join(dataset_dir, f'qpos_{i}.npy')):
return i
raise Exception(f"Error getting auto index, or more than {max_idx} episodes")

def main(args):
set_seed(1)
# command line parameters
Expand Down Expand Up @@ -318,7 +325,8 @@ def eval_bc(config, ckpt_name, save_episode=True, num_rollouts=50):
if temporal_agg:
all_time_actions = torch.zeros([max_timesteps, max_timesteps+num_queries, 16]).cuda()

qpos_history = torch.zeros((1, max_timesteps, state_dim)).cuda()
# qpos_history = torch.zeros((1, max_timesteps, state_dim)).cuda()
qpos_history_raw = np.zeros((max_timesteps, state_dim))
image_list = [] # for visualization
qpos_list = []
target_qpos_list = []
Expand All @@ -345,9 +353,10 @@ def eval_bc(config, ckpt_name, save_episode=True, num_rollouts=50):
else:
image_list.append({'main': obs['image']})
qpos_numpy = np.array(obs['qpos'])
qpos_history_raw[t] = qpos_numpy
qpos = pre_process(qpos_numpy)
qpos = torch.from_numpy(qpos).float().cuda().unsqueeze(0)
qpos_history[:, t] = qpos
# qpos_history[:, t] = qpos
if t % query_frequency == 0:
curr_image = get_image(ts, camera_names)
# print('get image: ', time.time() - time2)
Expand Down Expand Up @@ -453,7 +462,19 @@ def eval_bc(config, ckpt_name, save_episode=True, num_rollouts=50):
plt.close()
if real_robot:
move_grippers([env.puppet_bot_left, env.puppet_bot_right], [PUPPET_GRIPPER_JOINT_OPEN] * 2, move_time=0.5) # open
pass
# save qpos_history_raw
log_id = get_auto_index(ckpt_dir)
np.save(os.path.join(ckpt_dir, f'qpos_{log_id}.npy'), qpos_history_raw)
# plot qpos_history_raw for each qpos dim using subplots
for i in range(state_dim):
plt.subplot(state_dim, 1, i+1)
plt.plot(qpos_history_raw[:, i])
# remove x axis
if i != state_dim - 1:
plt.xticks([])
plt.tight_layout()
plt.savefig(os.path.join(ckpt_dir, f'qpos_{log_id}.png'))


rewards = np.array(rewards)
episode_return = np.sum(rewards[rewards!=None])
Expand Down

0 comments on commit 547a734

Please sign in to comment.