Skip to content

Commit

Permalink
add task configs to constant.py to reduce command line arguments
Browse files Browse the repository at this point in the history
  • Loading branch information
tonyzhaozh committed Mar 6, 2023
1 parent 092735d commit 5a33ee8
Show file tree
Hide file tree
Showing 11 changed files with 131 additions and 116 deletions.
25 changes: 8 additions & 17 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,11 @@ To set up a new terminal, run:

### Simulated experiments

We use ``transfer_cube`` task in the examples below. Another option is ``insertion``.
We use ``sim_transfer_cube_scripted`` task in the examples below. Another option is ``sim_insertion_scripted``.
To generated 50 episodes of scripted data, run:

python3 record_sim_episodes.py \
--task_name transfer_cube \
--task_name sim_transfer_cube_scripted \
--dataset_dir <data save dir> \
--num_episodes 50

Expand All @@ -64,24 +64,15 @@ To train ACT:

# Transfer Cube task
python3 imitate_episodes.py \
--dataset_dir <data save dir> \
--ckpt_dir <ckpt dir> \
--policy_class ACT --kl_weight 10 --chunk_size 100 --hidden_dim 256 --batch_size 8 --dim_feedforward 2048 \
--task_name transfer_cube --seed 0 \
--temporal_agg \
--num_epochs 1000 --lr 1e-4

# Bimanual Insertion task
python3 imitate_episodes.py \
--dataset_dir <data save dir> \
--task_name sim_transfer_cube_scripted \
--ckpt_dir <ckpt dir> \
--policy_class ACT --kl_weight 10 --chunk_size 100 --hidden_dim 256 --batch_size 8 --dim_feedforward 2048 \
--task_name insertion --seed 0 \
--temporal_agg \
--num_epochs 2000 --lr 1e-5
--policy_class ACT --kl_weight 10 --chunk_size 100 --hidden_dim 512 --batch_size 8 --dim_feedforward 3200 \
--num_epochs 2000 --lr 1e-5 \
--seed 0


To evaluate the policy, run the same command but add ``--eval``. The success rate
should be around 85% for transfer cube, and around 50% for insertion.
should be around 90% for transfer cube, and around 50% for insertion.
Videos will be saved to ``<ckpt_dir>`` for each rollout.
You can also add ``--onscreen_render`` to see real-time rendering during evaluation.

Expand Down
1 change: 0 additions & 1 deletion assets/scene.xml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@

<camera name="left_pillar" pos="-0.5 0.2 0.6" fovy="78" mode="targetbody" target="table"/>
<camera name="right_pillar" pos="0.5 0.2 0.6" fovy="78" mode="targetbody" target="table"/>
<camera name="main" pos="0 -0.2 0.4" fovy="78" mode="targetbody" target="midair"/>
<camera name="top" pos="0 0.6 0.8" fovy="78" mode="targetbody" target="table"/>
<camera name="angle" pos="0 0 0.6" fovy="78" mode="targetbody" target="table"/>
<camera name="front_close" pos="0 0.2 0.4" fovy="78" mode="targetbody" target="vx300s_left/camera_focus"/>
Expand Down
41 changes: 32 additions & 9 deletions constants.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,41 @@
import pathlib

### Parameters that changes across tasks
EPISODE_LEN = 600
### Task parameters
DATA_DIR = '<put your data dir here>'
SIM_TASK_CONFIGS = {
'sim_transfer_cube_scripted':{
'dataset_dir': DATA_DIR + '/sim_transfer_cube_scripted',
'num_episodes': 50,
'episode_len': 400,
'camera_names': ['top']
},

### ALOHA fixed constants
'sim_transfer_cube_human':{
'dataset_dir': DATA_DIR + '/sim_transfer_cube_human',
'num_episodes': 50,
'episode_len': 400,
'camera_names': ['top']
},

'sim_insertion_scripted': {
'dataset_dir': DATA_DIR + '/sim_insertion_scripted',
'num_episodes': 50,
'episode_len': 400,
'camera_names': ['top']
},

'sim_insertion_human': {
'dataset_dir': DATA_DIR + '/sim_insertion_human',
'num_episodes': 50,
'episode_len': 400,
'camera_names': ['top']
},
}

### Simulation envs fixed constants
DT = 0.02
JOINT_NAMES = ["waist", "shoulder", "elbow", "forearm_roll", "wrist_angle", "wrist_rotate"]
CAMERA_NAMES = ['cam_high', 'cam_low', 'cam_left_wrist', 'cam_right_wrist'] # defines the number and ordering of cameras
BOX_INIT_POSE = [0.2, 0.5, 0.05, 1, 0, 0, 0]
START_ARM_POSE = [0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239, 0, -0.96, 1.16, 0, -0.3, 0, 0.02239, -0.02239]
SIM_CAMERA_NAMES = ['main']

SIM_EPISODE_LEN_TRANSFER_CUBE = 400
SIM_EPISODE_LEN_INSERTION = 400

XML_DIR = str(pathlib.Path(__file__).parent.resolve()) + '/assets/' # note: absolute path

Expand Down
1 change: 0 additions & 1 deletion detr/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ def get_args_parser():
# repeat args in imitate_episodes just to avoid error. Will not be used
parser.add_argument('--eval', action='store_true')
parser.add_argument('--onscreen_render', action='store_true')
parser.add_argument('--dataset_dir', action='store', type=str, help='dataset_dir', required=True)
parser.add_argument('--ckpt_dir', action='store', type=str, help='ckpt_dir', required=True)
parser.add_argument('--policy_class', action='store', type=str, help='policy_class, capitalize', required=True)
parser.add_argument('--task_name', action='store', type=str, help='task_name', required=True)
Expand Down
15 changes: 9 additions & 6 deletions ee_sim_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,15 @@ def make_ee_sim_env(task_name):
right_gripper_qvel (1)] # normalized gripper velocity (pos: opening, neg: closing)
"images": {"main": (480x640x3)} # h, w, c, dtype='uint8'
"""
xml_path = os.path.join(XML_DIR, f'bimanual_viperx_ee_{task_name}.xml')
physics = mujoco.Physics.from_xml_path(xml_path)
if task_name == 'transfer_cube':
if 'sim_transfer_cube' in task_name:
xml_path = os.path.join(XML_DIR, f'bimanual_viperx_ee_transfer_cube.xml')
physics = mujoco.Physics.from_xml_path(xml_path)
task = TransferCubeEETask(random=False)
env = control.Environment(physics, task, time_limit=20, control_timestep=DT,
n_sub_steps=None, flat_observation=False)
elif task_name == 'insertion':
elif 'sim_insertion' in task_name:
xml_path = os.path.join(XML_DIR, f'bimanual_viperx_ee_insertion.xml')
physics = mujoco.Physics.from_xml_path(xml_path)
task = InsertionEETask(random=False)
env = control.Environment(physics, task, time_limit=20, control_timestep=DT,
n_sub_steps=None, flat_observation=False)
Expand Down Expand Up @@ -133,8 +135,9 @@ def get_observation(self, physics):
obs['qvel'] = self.get_qvel(physics)
obs['env_state'] = self.get_env_state(physics)
obs['images'] = dict()
obs['images']['main'] = physics.render(height=480, width=640, camera_id='main') # TODO hardcoded camera name

obs['images']['top'] = physics.render(height=480, width=640, camera_id='top')
obs['images']['angle'] = physics.render(height=480, width=640, camera_id='angle')
obs['images']['vis'] = physics.render(height=480, width=640, camera_id='front_close')
# used in scripted policy to obtain starting pose
obs['mocap_pose_left'] = np.concatenate([physics.data.mocap_pos[0], physics.data.mocap_quat[0]]).copy()
obs['mocap_pose_right'] = np.concatenate([physics.data.mocap_pos[1], physics.data.mocap_quat[1]]).copy()
Expand Down
59 changes: 29 additions & 30 deletions imitate_episodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@
from tqdm import tqdm
from einops import rearrange

from constants import DT, SIM_EPISODE_LEN_TRANSFER_CUBE, SIM_EPISODE_LEN_INSERTION, EPISODE_LEN
from constants import PUPPET_GRIPPER_JOINT_OPEN, CAMERA_NAMES, SIM_CAMERA_NAMES
from constants import DT
from constants import PUPPET_GRIPPER_JOINT_OPEN
from utils import load_data # data functions
from utils import sample_box_pose, sample_insertion_pose # robot functions
from utils import compute_dict_mean, set_seed, detach_dict # helper functions
Expand All @@ -26,16 +26,27 @@ def main(args):
# command line parameters
is_eval = args['eval']
ckpt_dir = args['ckpt_dir']
dataset_dir = args['dataset_dir']
policy_class = args['policy_class']
onscreen_render = args['onscreen_render']
task_name = args['task_name']
batch_size_train = args['batch_size']
batch_size_val = args['batch_size']
num_epochs = args['num_epochs']

# get task parameters
is_sim = task_name[:4] == 'sim_'
if is_sim:
from constants import SIM_TASK_CONFIGS
task_config = SIM_TASK_CONFIGS[task_name]
else:
from aloha_scripts.constants import TASK_CONFIGS
task_config = TASK_CONFIGS[task_name]
dataset_dir = task_config['dataset_dir']
num_episodes = task_config['num_episodes']
episode_len = task_config['episode_len']
camera_names = task_config['camera_names']

# fixed parameters
num_episodes = 50
state_dim = 14
lr_backbone = 1e-5
backbone = 'resnet18'
Expand All @@ -53,41 +64,31 @@ def main(args):
'enc_layers': enc_layers,
'dec_layers': dec_layers,
'nheads': nheads,
'camera_names': camera_names,
}
elif policy_class == 'CNNMLP':
policy_config = {'lr': args['lr'], 'lr_backbone': lr_backbone, 'backbone' : backbone, 'num_queries': 1}
policy_config = {'lr': args['lr'], 'lr_backbone': lr_backbone, 'backbone' : backbone, 'num_queries': 1,
'camera_names': camera_names,}
else:
raise NotImplementedError

config = {
'num_epochs': num_epochs,
'ckpt_dir': ckpt_dir,
'episode_len': episode_len,
'state_dim': state_dim,
'lr': args['lr'],
'real_robot': 'TBD',
'policy_class': policy_class,
'onscreen_render': onscreen_render,
'policy_config': policy_config,
'task_name': task_name,
'seed': args['seed'],
'temporal_agg': args['temporal_agg']
'temporal_agg': args['temporal_agg'],
'camera_names': camera_names,
'real_robot': not is_sim
}

train_dataloader, val_dataloader, stats, is_sim = load_data(dataset_dir, num_episodes, batch_size_train, batch_size_val)

if is_sim:
policy_config['camera_names'] = SIM_CAMERA_NAMES
config['camera_names'] = SIM_CAMERA_NAMES
config['real_robot'] = False
if task_name == 'transfer_cube':
config['episode_len'] = SIM_EPISODE_LEN_TRANSFER_CUBE
elif task_name == 'insertion':
config['episode_len'] = SIM_EPISODE_LEN_INSERTION
else:
policy_config['camera_names'] = CAMERA_NAMES
config['camera_names'] = CAMERA_NAMES
config['real_robot'] = True
config['episode_len'] = EPISODE_LEN
train_dataloader, val_dataloader, stats, _ = load_data(dataset_dir, num_episodes, camera_names, batch_size_train, batch_size_val)

if is_eval:
ckpt_names = [f'policy_best.ckpt']
Expand Down Expand Up @@ -159,7 +160,7 @@ def eval_bc(config, ckpt_name, save_episode=True):
max_timesteps = config['episode_len']
task_name = config['task_name']
temporal_agg = config['temporal_agg']
onscreen_cam = 'main'
onscreen_cam = 'angle'

# load policy and stats
ckpt_path = os.path.join(ckpt_dir, ckpt_name)
Expand All @@ -178,8 +179,8 @@ def eval_bc(config, ckpt_name, save_episode=True):

# load environment
if real_robot:
from scripts.utils import move_grippers # requires aloha
from scripts.real_env import make_real_env # requires aloha
from aloha_scripts.robot_utils import move_grippers # requires aloha
from aloha_scripts.real_env import make_real_env # requires aloha
env = make_real_env(init_node=True)
env_max_reward = 0
else:
Expand All @@ -200,12 +201,11 @@ def eval_bc(config, ckpt_name, save_episode=True):
for rollout_id in range(num_rollouts):
rollout_id += 0
### set task
if task_name == 'transfer_cube':
if 'sim_transfer_cube' in task_name:
BOX_POSE[0] = sample_box_pose() # used in sim reset
elif task_name == 'insertion':
elif 'sim_insertion' in task_name:
BOX_POSE[0] = np.concatenate(sample_insertion_pose()) # used in sim reset
else:
raise NotImplementedError

ts = env.reset()

### onscreen render
Expand Down Expand Up @@ -417,7 +417,6 @@ def plot_history(train_history, validation_history, num_epochs, ckpt_dir, seed):
parser = argparse.ArgumentParser()
parser.add_argument('--eval', action='store_true')
parser.add_argument('--onscreen_render', action='store_true')
parser.add_argument('--dataset_dir', action='store', type=str, help='dataset_dir', required=True)
parser.add_argument('--ckpt_dir', action='store', type=str, help='ckpt_dir', required=True)
parser.add_argument('--policy_class', action='store', type=str, help='policy_class, capitalize', required=True)
parser.add_argument('--task_name', action='store', type=str, help='task_name', required=True)
Expand Down
35 changes: 19 additions & 16 deletions record_sim_episodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,7 @@
import matplotlib.pyplot as plt
import h5py_cache

from constants import PUPPET_GRIPPER_POSITION_NORMALIZE_FN
from constants import SIM_EPISODE_LEN_TRANSFER_CUBE, SIM_EPISODE_LEN_INSERTION, SIM_CAMERA_NAMES
from constants import PUPPET_GRIPPER_POSITION_NORMALIZE_FN, SIM_TASK_CONFIGS
from ee_sim_env import make_ee_sim_env
from sim_env import make_sim_env, BOX_POSE
from scripted_policy import PickAndTransferPolicy, InsertionPolicy
Expand All @@ -29,21 +28,24 @@ def main(args):
num_episodes = args['num_episodes']
onscreen_render = args['onscreen_render']
inject_noise = False
render_cam_name = 'angle'

if not os.path.isdir(dataset_dir):
os.makedirs(dataset_dir, exist_ok=True)

if task_name == 'transfer_cube':
episode_len = SIM_TASK_CONFIGS[task_name]['episode_len']
camera_names = SIM_TASK_CONFIGS[task_name]['camera_names']
if task_name == 'sim_transfer_cube_scripted':
policy_cls = PickAndTransferPolicy
episode_len = SIM_EPISODE_LEN_TRANSFER_CUBE
elif task_name == 'insertion':
elif task_name == 'sim_insertion_scripted':
policy_cls = InsertionPolicy
episode_len = SIM_EPISODE_LEN_INSERTION
else:
raise NotImplementedError

success = []
for episode_idx in range(num_episodes):
print(f'{episode_idx=}')
print('Rollout out EE space scripted policy')
# setup the environment
env = make_ee_sim_env(task_name)
ts = env.reset()
Expand All @@ -52,14 +54,14 @@ def main(args):
# setup plotting
if onscreen_render:
ax = plt.subplot()
plt_img = ax.imshow(ts.observation['images']['main'])
plt_img = ax.imshow(ts.observation['images'][render_cam_name])
plt.ion()
for step in range(episode_len):
action = policy(ts)
ts = env.step(action)
episode.append(ts)
if onscreen_render:
plt_img.set_data(ts.observation['images']['main'])
plt_img.set_data(ts.observation['images'][render_cam_name])
plt.pause(0.002)
plt.close()

Expand Down Expand Up @@ -87,7 +89,7 @@ def main(args):
del policy

# setup the environment
print(f'====== Start Replaying ======')
print('Replaying joint commands')
env = make_sim_env(task_name)
BOX_POSE[0] = subtask_info # make sure the sim_env has the same object configurations as ee_sim_env
ts = env.reset()
Expand All @@ -96,14 +98,14 @@ def main(args):
# setup plotting
if onscreen_render:
ax = plt.subplot()
plt_img = ax.imshow(ts.observation['images']['main'])
plt_img = ax.imshow(ts.observation['images'][render_cam_name])
plt.ion()
for t in range(len(joint_traj)): # note: this will increase episode length by 1
action = joint_traj[t]
ts = env.step(action)
episode_replay.append(ts)
if onscreen_render:
plt_img.set_data(ts.observation['images']['main'])
plt_img.set_data(ts.observation['images'][render_cam_name])
plt.pause(0.02)

episode_return = np.sum([ts.reward for ts in episode_replay[1:]])
Expand All @@ -121,7 +123,7 @@ def main(args):
For each timestep:
observations
- images
- main (480, 640, 3) 'uint8'
- each_cam_name (480, 640, 3) 'uint8'
- qpos (14,) 'float64'
- qvel (14,) 'float64'
Expand All @@ -133,7 +135,7 @@ def main(args):
'/observations/qvel': [],
'/action': [],
}
for cam_name in SIM_CAMERA_NAMES:
for cam_name in camera_names:
data_dict[f'/observations/images/{cam_name}'] = []

# because the replaying, there will be eps_len + 1 actions and eps_len + 2 timesteps
Expand All @@ -150,7 +152,7 @@ def main(args):
data_dict['/observations/qpos'].append(ts.observation['qpos'])
data_dict['/observations/qvel'].append(ts.observation['qvel'])
data_dict['/action'].append(action)
for cam_name in SIM_CAMERA_NAMES:
for cam_name in camera_names:
data_dict[f'/observations/images/{cam_name}'].append(ts.observation['images'][cam_name])

# HDF5
Expand All @@ -161,8 +163,9 @@ def main(args):
root.attrs['sim'] = True
obs = root.create_group('observations')
image = obs.create_group('images')
cam_main = image.create_dataset('main', (max_timesteps, 480, 640, 3), dtype='uint8',
chunks=(1, 480, 640, 3), )
for cam_name in camera_names:
_ = image.create_dataset(cam_name, (max_timesteps, 480, 640, 3), dtype='uint8',
chunks=(1, 480, 640, 3), )
# compression='gzip',compression_opts=2,)
# compression=32001, compression_opts=(0, 0, 0, 0, 9, 1, 1), shuffle=False)
qpos = obs.create_dataset('qpos', (max_timesteps, 14))
Expand Down
Loading

0 comments on commit 5a33ee8

Please sign in to comment.