
Commit

Merge branch 'main' of github.com:MarkFzp/act-plus-plus into main
MarkFzp committed Nov 21, 2023
2 parents de1c002 + c593d17 commit ad52735
Showing 8 changed files with 771 additions and 29 deletions.
53 changes: 52 additions & 1 deletion commands.txt
@@ -58,7 +58,7 @@ CUDA_VISIBLE_DEVICES=0 python3 imitate_episodes.py \
--ckpt_dir /scr/tonyzhao/train_logs/cube_scripted \
--policy_class ACT --kl_weight 10 --chunk_size 50 \
--hidden_dim 512 --batch_size 12 --dim_feedforward 3200 --lr 1e-5 --seed 0 \
- --num_steps 100000 --eval_every 2000 --validate_every 2000 --save_every 2000
--num_steps 100000 --eval_every 2000 --validate_every 2000 --save_every 2000 --no_encoder


# launch experiment on all data
@@ -70,8 +70,59 @@ CUDA_VISIBLE_DEVICES=0 python3 imitate_episodes.py \
--ckpt_dir /scr/tonyzhao/train_logs/cube_scripted_mirror \
--policy_class ACT --kl_weight 10 --chunk_size 50 \
--hidden_dim 512 --batch_size 12 --dim_feedforward 3200 --lr 1e-5 --seed 0 \
--num_steps 100000 --eval_every 2000 --validate_every 2000 --save_every 2000 --no_encoder


####### DIFFUSION POLICY

- first install https://github.com/ARISE-Initiative/robomimic/tree/r2d2 (note: the r2d2 branch, not the default one)
- on top of it, pip-install this repo's requirements (a setup sketch follows below)
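A minimal sketch of that setup, assuming a fresh conda env; the requirements file name below is a guess, check the repo root:

git clone -b r2d2 https://github.com/ARISE-Initiative/robomimic.git
pip install -e robomimic/
pip install -r requirements.txt  # run from the act-plus-plus root, if the file exists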


conda activate mobile
export MUJOCO_GL=egl
cd /home/tonyzhao/Research/act-plus-plus
CUDA_VISIBLE_DEVICES=0 python3 imitate_episodes.py \
--task_name sim_transfer_cube_scripted \
--ckpt_dir /scr/tonyzhao/train_logs/cube_scripted_diffusion_sweep_0 \
--policy_class Diffusion --chunk_size 32 \
--batch_size 32 --lr 1e-5 --seed 0 \
--num_steps 100000 --eval_every 2000 --validate_every 2000 --save_every 2000


conda activate mobile
export MUJOCO_GL=egl
cd /home/tonyzhao/Research/act-plus-plus
CUDA_VISIBLE_DEVICES=1 python3 imitate_episodes.py \
--task_name sim_transfer_cube_scripted \
--ckpt_dir /scr/tonyzhao/train_logs/cube_scripted_diffusion_sweep_1 \
--policy_class Diffusion --chunk_size 16 \
--batch_size 32 --lr 1e-5 --seed 0 \
--num_steps 100000 --eval_every 2000 --validate_every 2000 --save_every 2000


# the runs above all use 100 train diffusion steps and lr 1e-5

conda activate mobile
export MUJOCO_GL=egl
cd /home/tonyzhao/Research/act-plus-plus
CUDA_VISIBLE_DEVICES=1 python3 imitate_episodes.py \
--task_name sim_transfer_cube_scripted \
--ckpt_dir /scr/tonyzhao/train_logs/cube_scripted_diffusion_sweep_2_50step_1e-4 \
--policy_class Diffusion --chunk_size 32 \
--batch_size 32 --lr 1e-4 --seed 0 \
--num_steps 100000 --eval_every 2000 --validate_every 2000 --save_every 2000



---------------------------------------------------------------------------------------

NOTE: chunk size cannot be an arbitrary number (the diffusion UNet downsamples along the time axis, which is likely why the sweeps above use powers of two like 16 and 32); try it before launching a long run
TODO: add history conditioning; add EMA at test time

conda activate mobile
cd /home/tonyzhao/Research/act-plus-plus
CUDA_VISIBLE_DEVICES=1 python3 train_actuator_network.py
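# to use the trained actuator network at eval time, pass --actuator_network_dir
# (plus --history_len / --future_len / --prediction_len) to imitate_episodes.py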



4 changes: 2 additions & 2 deletions compress_data.py
@@ -99,7 +99,7 @@ def save_videos(video, dt, video_path=None):
images = []
for cam_name in cam_names:
image = image_dict[cam_name]
- # image = image[:, :, [2, 1, 0]] # swap B and R channel
image = image[:, :, [2, 1, 0]] # swap B and R channel
images.append(image)
images = np.concatenate(images, axis=1)
out.write(images)
@@ -119,7 +119,7 @@ def save_videos(video, dt, video_path=None):
out = cv2.VideoWriter(video_path, cv2.VideoWriter_fourcc(*'mp4v'), fps, (w, h))
for t in range(n_frames):
image = all_cam_videos[t]
- # image = image[:, :, [2, 1, 0]] # swap B and R channel
image = image[:, :, [2, 1, 0]] # swap B and R channel
out.write(image)
out.release()
print(f'Saved video to: {video_path}')
4 changes: 4 additions & 0 deletions detr/main.py
@@ -75,6 +75,10 @@ def get_args_parser():
parser.add_argument('--resume_ckpt_path', action='store', type=str, help='load_ckpt_path', required=False)
parser.add_argument('--no_encoder', action='store_true')
parser.add_argument('--skip_mirrored_data', action='store_true')
parser.add_argument('--actuator_network_dir', action='store', type=str, help='actuator_network_dir', required=False)
parser.add_argument('--history_len', action='store', type=int)
parser.add_argument('--future_len', action='store', type=int)
parser.add_argument('--prediction_len', action='store', type=int)
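# the four actuator-network flags above mirror the ones added to imitate_episodes.py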

return parser

98 changes: 89 additions & 9 deletions imitate_episodes.py
@@ -15,7 +15,7 @@
from utils import load_data # data functions
from utils import sample_box_pose, sample_insertion_pose # robot functions
from utils import compute_dict_mean, set_seed, detach_dict, calibrate_linear_vel, postprocess_base_action # helper functions
- from policy import ACTPolicy, CNNMLPPolicy
from policy import ACTPolicy, CNNMLPPolicy, DiffusionPolicy
from visualize_episodes import save_videos

from detr.models.latent_model import Latent_Model_Transformer
@@ -80,12 +80,32 @@ def main(args):
'action_dim': 16,
'no_encoder': args['no_encoder'],
}
elif policy_class == 'Diffusion':

policy_config = {'lr': args['lr'],
'camera_names': camera_names,
'action_dim': 16,
'observation_horizon': 1,
'action_horizon': 8,
'prediction_horizon': args['chunk_size'],
'num_queries': args['chunk_size'],
'num_inference_timesteps': 10,
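# 10 denoising steps at inference; training presumably uses the longer schedule noted in commands.txt ('100 train diffusion steps')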
'ema_power': 0.75,
'vq': False,
}
elif policy_class == 'CNNMLP':
policy_config = {'lr': args['lr'], 'lr_backbone': lr_backbone, 'backbone' : backbone, 'num_queries': 1,
'camera_names': camera_names,}
else:
raise NotImplementedError

actuator_config = {
'actuator_network_dir': args['actuator_network_dir'],
'history_len': args['history_len'],
'future_len': args['future_len'],
'prediction_len': args['prediction_len'],
}

config = {
'num_steps': num_steps,
'eval_every': eval_every,
@@ -104,15 +104,16 @@
'temporal_agg': args['temporal_agg'],
'camera_names': camera_names,
'real_robot': not is_sim,
- 'load_pretrain': args['load_pretrain']
'load_pretrain': args['load_pretrain'],
'actuator_config': actuator_config,
}

if not os.path.isdir(ckpt_dir):
os.makedirs(ckpt_dir)
config_path = os.path.join(ckpt_dir, 'config.pkl')
expr_name = ckpt_dir.split('/')[-1]
if not is_eval:
wandb.init(project="mobile-aloha", reinit=True, entity="mobile-aloha", name=expr_name)
wandb.init(project="mobile-aloha2", reinit=True, entity="mobile-aloha2", name=expr_name)
wandb.config.update(config)
with open(config_path, 'wb') as f:
pickle.dump(config, f)
@@ -129,7 +150,7 @@
print()
exit()

- train_dataloader, val_dataloader, stats, _ = load_data(dataset_dir, name_filter, camera_names, batch_size_train, batch_size_val, args['chunk_size'], args['skip_mirrored_data'], config['load_pretrain'])
train_dataloader, val_dataloader, stats, _ = load_data(dataset_dir, name_filter, camera_names, batch_size_train, batch_size_val, args['chunk_size'], args['skip_mirrored_data'], config['load_pretrain'], policy_class)

# save dataset stats
stats_path = os.path.join(ckpt_dir, f'dataset_stats.pkl')
@@ -151,6 +172,8 @@ def make_policy(policy_class, policy_config):
policy = ACTPolicy(policy_config)
elif policy_class == 'CNNMLP':
policy = CNNMLPPolicy(policy_config)
elif policy_class == 'Diffusion':
policy = DiffusionPolicy(policy_config)
else:
raise NotImplementedError
return policy
@@ -161,6 +184,8 @@ def make_optimizer(policy_class, policy):
optimizer = policy.configure_optimizers()
elif policy_class == 'CNNMLP':
optimizer = policy.configure_optimizers()
elif policy_class == 'Diffusion':
optimizer = policy.configure_optimizers()
else:
raise NotImplementedError
return optimizer
@@ -190,6 +215,8 @@ def eval_bc(config, ckpt_name, save_episode=True, num_rollouts=50):
temporal_agg = config['temporal_agg']
onscreen_cam = 'angle'
vq = config['policy_config']['vq']
actuator_config = config['actuator_config']
use_actuator_net = actuator_config['actuator_network_dir'] is not None

# load policy and stats
ckpt_path = os.path.join(ckpt_dir, ckpt_name)
@@ -212,9 +239,35 @@ def eval_bc(config, ckpt_name, save_episode=True, num_rollouts=50):
stats_path = os.path.join(ckpt_dir, f'dataset_stats.pkl')
with open(stats_path, 'rb') as f:
stats = pickle.load(f)
if use_actuator_net:
prediction_len = actuator_config['prediction_len']
future_len = actuator_config['future_len']
history_len = actuator_config['history_len']
actuator_network_dir = actuator_config['actuator_network_dir']

from train_actuator_network import ActuatorNetwork
actuator_network = ActuatorNetwork(prediction_len)
actuator_network_path = os.path.join(actuator_network_dir, 'actuator_net_last.ckpt')
loading_status = actuator_network.load_state_dict(torch.load(actuator_network_path))
actuator_network.eval()
actuator_network.cuda()
print(f'Loaded actuator network from: {actuator_network_path}, {loading_status}')

actuator_stats_path = os.path.join(actuator_network_dir, 'actuator_net_stats.pkl')
with open(actuator_stats_path, 'rb') as f:
actuator_stats = pickle.load(f)

actuator_unnorm = lambda x: x * actuator_stats['action_std'] + actuator_stats['action_mean']
actuator_norm = lambda x: (x - actuator_stats['action_mean']) / actuator_stats['action_std']
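# helper: append the policy's post-processed base actions (the last two action dims),
# re-normalized with the actuator-net stats, to the running history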
def collect_base_action(all_actions, norm_episode_all_base_actions):
post_processed_actions = post_process(all_actions.squeeze(0).cpu().numpy())
norm_episode_all_base_actions += actuator_norm(post_processed_actions[:, -2:]).tolist()

pre_process = lambda s_qpos: (s_qpos - stats['qpos_mean']) / stats['qpos_std']
- post_process = lambda a: a * stats['action_std'] + stats['action_mean']
if policy_class == 'Diffusion':
post_process = lambda a: ((a + 1) / 2) * (stats['action_max'] - stats['action_min']) + stats['action_min']
else:
post_process = lambda a: a * stats['action_std'] + stats['action_mean']
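# note: the two branches undo different training-time normalizations. Diffusion learns
# actions min-max scaled to [-1, 1]; ACT/CNNMLP learn z-scored actions. Quick check of
# the Diffusion inverse with hypothetical values: a_min, a_max = -0.5, 1.5 and a = 0.0
# give ((a + 1) / 2) * (a_max - a_min) + a_min == 0.5, the midpoint of [a_min, a_max].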

# load environment
if real_robot:
@@ -263,6 +316,7 @@ def eval_bc(config, ckpt_name, save_episode=True, num_rollouts=50):
qpos_list = []
target_qpos_list = []
rewards = []
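# pre-fill the base-action history with history_len zero actions (in actuator-net
# normalized space) so the first actuator-net window is full length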
norm_episode_all_base_actions = actuator_norm(np.zeros((history_len, 2))).tolist() if use_actuator_net else []
with torch.inference_mode():
for t in range(max_timesteps):
### update onscreen render and wait for DT
@@ -296,6 +350,8 @@ def eval_bc(config, ckpt_name, save_episode=True, num_rollouts=50):
else:
# e()
all_actions = policy(qpos, curr_image)
if use_actuator_net:
collect_base_action(all_actions, norm_episode_all_base_actions)
if temporal_agg:
all_time_actions[[t], t:t+num_queries] = all_actions
actions_for_curr_step = all_time_actions[:, t]
@@ -308,18 +364,38 @@ def eval_bc(config, ckpt_name, save_episode=True, num_rollouts=50):
raw_action = (actions_for_curr_step * exp_weights).sum(dim=0, keepdim=True)
else:
raw_action = all_actions[:, t % query_frequency]
elif config['policy_class'] == "Diffusion":
if t % query_frequency == 0:
all_actions = policy(qpos, curr_image)
if use_actuator_net:
collect_base_action(all_actions, norm_episode_all_base_actions)
raw_action = all_actions[:, t % query_frequency]
elif config['policy_class'] == "CNNMLP":
raw_action = policy(qpos, curr_image)
all_actions = raw_action.unsqueeze(0)
if use_actuator_net:
collect_base_action(all_actions, norm_episode_all_base_actions)
else:
raise NotImplementedError

### post-process actions
raw_action = raw_action.squeeze(0).cpu().numpy()
action = post_process(raw_action)
target_qpos = action[:-2]
- base_action = action[-2:]
- # base_action = calibrate_linear_vel(base_action, c=0.19)
- # base_action = postprocess_base_action(base_action)

if use_actuator_net:
assert(not temporal_agg)
if t % prediction_len == 0:
offset_start_ts = t + history_len
actuator_net_in = np.array(norm_episode_all_base_actions[offset_start_ts - history_len: offset_start_ts + future_len])
actuator_net_in = torch.from_numpy(actuator_net_in).float().unsqueeze(dim=0).cuda()
pred = actuator_network(actuator_net_in)
base_action_chunk = actuator_unnorm(pred.detach().cpu().numpy()[0])
base_action = base_action_chunk[t % prediction_len]
else:
base_action = action[-2:]
# base_action = calibrate_linear_vel(base_action, c=0.19)
# base_action = postprocess_base_action(base_action)
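# window arithmetic above, assuming the policy is re-queried every chunk_size steps so
# the buffer gains one normalized base action per future step: history_len zeros pad the
# front, hence buffer index offset_start_ts = t + history_len lines up with step t, and
# the slice [offset_start_ts - history_len : offset_start_ts + future_len] covers real
# steps [t - history_len, t + future_len)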

### step the environment
if real_robot:
@@ -418,7 +494,7 @@ def train_bc(train_dataloader, val_dataloader, config):
if epoch_val_loss < min_val_loss:
min_val_loss = epoch_val_loss
best_ckpt_info = (step, min_val_loss, deepcopy(policy.state_dict()))
- for k in validation_summary.keys():
for k in list(validation_summary.keys()):
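# list() materializes the keys first; popping while iterating over keys() would raise RuntimeError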
validation_summary[f'val_{k}'] = validation_summary.pop(k)
wandb.log(validation_summary, step=step)
print(f'Val loss: {epoch_val_loss:.5f}')
@@ -487,6 +563,10 @@ def repeater(data_loader):
parser.add_argument('--save_every', action='store', type=int, default=500, help='save_every', required=False)
parser.add_argument('--resume_ckpt_path', action='store', type=str, help='resume_ckpt_path', required=False)
parser.add_argument('--skip_mirrored_data', action='store_true')
parser.add_argument('--actuator_network_dir', action='store', type=str, help='actuator_network_dir', required=False)
parser.add_argument('--history_len', action='store', type=int)
parser.add_argument('--future_len', action='store', type=int)
parser.add_argument('--prediction_len', action='store', type=int)
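# example eval invocation with the actuator network enabled (hypothetical paths/values):
# python3 imitate_episodes.py --eval --task_name ... --ckpt_dir ... \
#     --actuator_network_dir /scr/tonyzhao/train_logs/actuator_net \
#     --history_len 50 --future_len 50 --prediction_len 50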

# for ACT
parser.add_argument('--kl_weight', action='store', type=int, help='KL Weight', required=False)
