Skip to content

Commit

Permalink
fix keypoint bug on different shapes
Browse files Browse the repository at this point in the history
  • Loading branch information
yaoxingcheng committed Dec 11, 2024
1 parent 521b82c commit 8c05746
Show file tree
Hide file tree
Showing 9 changed files with 32 additions and 13 deletions.
2 changes: 1 addition & 1 deletion config/image_pusht_dp_tf_hgr_20pc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ task:
type: rgb
task_name: pusht_image
training:
checkpoint_every: 50
checkpoint_every: 20
debug: false
device: cuda:0
gradient_accumulate_every: 1
Expand Down
3 changes: 3 additions & 0 deletions diffusion_policy/env/pusht/pusht_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,9 @@ def seed(self, seed=None):
def _handle_collision(self, arbiter, space, data):
self.n_contact_points += len(arbiter.contact_point_set.points)

def set_block_shape(self, shape):
self.block_shape = shape

def _set_state(self, state):
if isinstance(state, np.ndarray):
state = state.tolist()
Expand Down
15 changes: 12 additions & 3 deletions diffusion_policy/env/pusht/pusht_keypoints_env.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@
from diffusion_policy.env.pusht.pymunk_keypoint_manager import PymunkKeypointManager
import numpy as np

KP_MANAGER_DICT = {
"tee": PymunkKeypointManager.create_from_pusht_env(PushTEnv(), block_shape="tee"),
"gamma": PymunkKeypointManager.create_from_pusht_env(PushTEnv(), block_shape="gamma"),
"al": PymunkKeypointManager.create_from_pusht_env(PushTEnv(), block_shape="al"),
"vee": PymunkKeypointManager.create_from_pusht_env(PushTEnv(), block_shape="vee"),
}
class PushTKeypointsEnv(PushTEnv):
def __init__(self,
legacy=False,
Expand Down Expand Up @@ -63,9 +69,8 @@ def __init__(self,
self.keypoint_visible_rate = keypoint_visible_rate
self.agent_keypoints = agent_keypoints
self.draw_keypoints = draw_keypoints
self.kp_manager = PymunkKeypointManager(
local_keypoint_map=local_keypoint_map,
color_map=color_map)
self.kp_manager = KP_MANAGER_DICT["tee"]

self.draw_kp_map = None

@classmethod
Expand All @@ -74,6 +79,10 @@ def genenerate_keypoint_manager_params(cls):
kp_manager = PymunkKeypointManager.create_from_pusht_env(env)
kp_kwargs = kp_manager.kwargs
return kp_kwargs

def set_block_shape(self, shape):
self.block_shape = shape
self.kp_manager = KP_MANAGER_DICT[shape]

def _get_obs(self):
# get keypoints
Expand Down
11 changes: 9 additions & 2 deletions diffusion_policy/env/pusht/pymunk_keypoint_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def kwargs(self):
}

@classmethod
def create_from_pusht_env(cls, env, n_block_kps=9, n_agent_kps=3, seed=0, **kwargs):
def create_from_pusht_env(cls, env, n_block_kps=9, n_agent_kps=3, seed=0, block_shape="tee", **kwargs):
rng = np.random.default_rng(seed=seed)
local_keypoint_map = dict()
for name in ['block','agent']:
Expand All @@ -60,7 +60,14 @@ def create_from_pusht_env(cls, env, n_block_kps=9, n_agent_kps=3, seed=0, **kwar
self.agent = obj = self.add_circle((256, 400), 15)
n_kps = n_agent_kps
else:
self.block = obj = self.add_tee((256, 300), 0)
if block_shape == "tee":
self.block = obj = self.add_tee((256, 300), 0)
elif block_shape == "gamma":
self.block = obj = self.add_gamma((256, 300), 0)
elif block_shape == "al":
self.block = obj = self.add_al((256, 300), 0)
elif block_shape == "vee":
self.block = obj = self.add_vee((256, 300), 0)
n_kps = n_block_kps

self.screen = pygame.Surface((512,512))
Expand Down
2 changes: 1 addition & 1 deletion diffusion_policy/env_runner/pusht_image_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,7 +122,7 @@ def init_fn(env, seed=seed, enable_render=enable_render, shape=shape_name):

# set shape
assert isinstance(env.env.env, PushTImageEnv)
env.env.env.block_shape = shape
env.env.env.set_block_shape(shape)

env_seeds.append(seed)
env_prefixs.append(f'test/{shape_name}_')
Expand Down
2 changes: 1 addition & 1 deletion diffusion_policy/env_runner/pusht_keypoints_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ def init_fn(env, seed=seed, enable_render=enable_render, shape=shape_name):

# set shape
assert isinstance(env.env.env, PushTKeypointsEnv)
env.env.env.block_shape = shape
env.env.env.set_block_shape(shape)

env_seeds.append(seed)
env_prefixs.append(f'test/{shape_name}_')
Expand Down
2 changes: 1 addition & 1 deletion eval.sh
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# python eval.py --checkpoint data/0550-test_mean_score=0.969.ckpt --output_dir data/pusht_eval_output --device cuda:0
# python eval.py --checkpoint /mnt/diffusion_policy/data/outputs/2024.11.25/00.43.25_train_diffusion_unet_hybrid_pusht_image/checkpoints/epoch=0100-test_mean_score=0.804.ckpt --output_dir data/pusht_eval_output --device cuda:0
# python eval.py --checkpoint /mnt/diffusion_policy/data/outputs/2024.11.25/02.30.25_train_diffusion_unet_hybrid_pusht_image/checkpoints/epoch=0350-test_mean_score=0.534.ckpt --output_dir data/pusht_eval_output --device cuda:0
python eval.py --checkpoint /local2/xingcheng/diffusion_policy/data/outputs/2024.11.25/00.19.26_train_diffusion_transformer_lowdim_pusht_lowdim/checkpoints/epoch=1000-test_mean_score=0.845.ckpt --output_dir data/pusht_all_shape_eval_output --device cuda:0
python eval.py --checkpoint /local2/xingcheng/diffusion_policy/data/outputs/2024.11.25/00.19.26_train_diffusion_transformer_lowdim_pusht_lowdim/checkpoints/epoch=1000-test_mean_score=0.845.ckpt --output_dir data/pusht_all_shape_kpbugfixed_eval_output --device cuda:7
2 changes: 1 addition & 1 deletion hgr.sh
Original file line number Diff line number Diff line change
@@ -1 +1 @@
python hgr_pusht.py -o data/hgr_pusht_1000epoch_base_20.zarr -c /local2/xingcheng/diffusion_policy/data/outputs/2024.11.25/00.19.26_train_diffusion_transformer_lowdim_pusht_lowdim/checkpoints/epoch=1000-test_mean_score=0.845.ckpt -d cuda:0
python hgr_pusht.py -o data/hgr_pusht_1000epoch_base_20_kpright.zarr -c /local2/xingcheng/diffusion_policy/data/outputs/2024.11.25/00.19.26_train_diffusion_transformer_lowdim_pusht_lowdim/checkpoints/epoch=1000-test_mean_score=0.845.ckpt -d cuda:7
6 changes: 3 additions & 3 deletions hgr_pusht.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,9 @@ def main(output, checkpoint, device, num_trajs_per_shape, reward_thres):
episode = list()
# set seed for env
env.seed(seed)
env.block_shape = shape
env.set_block_shape(shape)
menv.seed(seed)
menv.env.block_shape = shape
menv.env.set_block_shape(shape)
# reset env and get observations (including info and render for recording)
obs = env.reset()
info = env._get_info()
Expand Down Expand Up @@ -176,7 +176,7 @@ def main(output, checkpoint, device, num_trajs_per_shape, reward_thres):

# First transform the initial position of the block
env.seed(seed)
env.block_shape = shape
env.set_block_shape(shape)
obs = env.reset()
init_info = env._get_info()
block_pose = init_info['block_pose']
Expand Down

0 comments on commit 8c05746

Please sign in to comment.