batched HMR inference.
brjathu committed Jul 2, 2023
1 parent ebc061f commit 2b77b4d
Showing 3 changed files with 71 additions and 58 deletions.
phalp/models/backbones/resnet.py (2 changes: 1 addition & 1 deletion)
@@ -156,7 +156,7 @@ def forward(self, x):
x4 = self.layer4(x3)

if(self.cfg.MODEL.BACKBONE.MASK_TYPE=="feat"):
x5 = copy.deepcopy(x4)
x5 = x4.clone()
x5 = x5*x_
return x5, [x1,x2,x3,x4]
else:
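The resnet.py change replaces copy.deepcopy of the layer4 feature map with tensor.clone(). A minimal standalone sketch (not PHALP code) of why clone() is the safer way to copy an intermediate activation:

```python
import copy
import torch

feat = torch.randn(1, 2048, 8, 8, requires_grad=True)
x4 = feat * 1.0          # non-leaf tensor, like a backbone activation with gradients enabled

x5 = x4.clone()          # fine: same device and dtype, stays in the autograd graph
try:
    copy.deepcopy(x4)    # raises: deepcopy only supports leaf tensors
except RuntimeError as err:
    print(err)
```

Under torch.no_grad() both calls succeed, so for pure inference the swap mainly avoids the deepcopy machinery; with gradients enabled only clone() works.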
phalp/models/hmar/hmar.py (8 changes: 4 additions & 4 deletions)
@@ -119,14 +119,14 @@ def get_3d_parameters(self, pred_smpl_params, pred_cam, center=np.array([128, 12
pred_cam_t = torch.stack([pred_cam[:,1], pred_cam[:,2], 2*focal_length[:, 0]/(pred_cam[:,0]*torch.tensor(scale[:, 0], dtype=dtype, device=device) + 1e-9)], dim=1)
pred_cam_t[:, :2] += torch.tensor(center-img_size/2., dtype=dtype, device=device) * pred_cam_t[:, [2]] / focal_length

zeros_ = torch.zeros(batch_size, 1, 3).cuda()
zeros_ = torch.zeros(batch_size, 1, 3).to(device)
pred_joints = torch.cat((pred_joints, zeros_), 1)

camera_center = torch.zeros(batch_size, 2)
pred_keypoints_2d_smpl = perspective_projection(pred_joints, rotation=torch.eye(3,).unsqueeze(0).expand(batch_size, -1, -1).cuda(),
translation=pred_cam_t.cuda(),
pred_keypoints_2d_smpl = perspective_projection(pred_joints, rotation=torch.eye(3,).unsqueeze(0).expand(batch_size, -1, -1).to(device),
translation=pred_cam_t.to(device),
focal_length=focal_length / img_size,
camera_center=camera_center.cuda())
camera_center=camera_center.to(device))

pred_keypoints_2d_smpl = (pred_keypoints_2d_smpl+0.5)*img_size

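The hmar.py edits swap hard-coded .cuda() calls for .to(device), with device being the variable already used a few lines up (dtype=dtype, device=device), so get_3d_parameters no longer assumes a GPU. A small sketch of that device-agnostic pattern; the function name is illustrative, not from the repo:

```python
import torch

def pad_joints(pred_joints: torch.Tensor) -> torch.Tensor:
    # Derive device and dtype from an existing tensor instead of calling .cuda(),
    # so the same code runs unchanged on CPU or on any GPU index.
    device, dtype = pred_joints.device, pred_joints.dtype
    batch_size = pred_joints.shape[0]
    zeros_ = torch.zeros(batch_size, 1, 3, dtype=dtype, device=device)
    return torch.cat((pred_joints, zeros_), dim=1)

print(pad_joints(torch.randn(4, 44, 3)).shape)   # torch.Size([4, 45, 3]), on the input's device
```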
phalp/trackers/PHALP.py (119 changes: 66 additions & 53 deletions)
@@ -194,11 +194,7 @@ def track(self):
pred_bbox, pred_masks, pred_scores, pred_classes, gt_tids, gt_annots = self.get_detections(image_frame, frame_name, t_, additional_data, measurments)

############ HMAR ##############
detections = []
for bbox, mask, score, cls_id, gt_tid, gt_ann in zip(pred_bbox, pred_masks, pred_scores, pred_classes, gt_tids, gt_annots):
if (bbox[2]-bbox[0]<self.cfg.phalp.small_w or bbox[3]-bbox[1]<self.cfg.phalp.small_h) and len(gt_ann)==0: continue
detection_data = self.get_human_features(image_frame, mask, bbox, score, frame_name, cls_id, t_, measurments, gt_tid, gt_ann)
detections.append(Detection(detection_data))
detections = self.get_human_features(image_frame, pred_masks, pred_bbox, pred_scores, frame_name, pred_classes, t_, measurments, gt_tids, gt_annots)

############ tracking ##############
self.tracker.predict()
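With batched HMR inference, track() hands all of a frame's detections to get_human_features in a single call instead of looping over people one at a time. A self-contained sketch of the looped-versus-batched pattern, using a stand-in model rather than HMAR:

```python
import torch

model = torch.nn.Linear(3 * 128 * 128, 64).eval()               # stand-in for a per-crop network
crops = [torch.randn(3, 128, 128) for _ in range(8)]            # one crop per detection

with torch.no_grad():
    per_person = [model(c.reshape(1, -1)) for c in crops]       # old style: N forward passes
    batch = torch.stack([c.reshape(-1) for c in crops], dim=0)  # new style: stack once,
    batched = model(batch)                                      # one forward pass for the frame

assert torch.allclose(torch.cat(per_person, dim=0), batched, atol=1e-4)
```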
@@ -358,71 +354,88 @@ def get_croped_image(self, image, bbox, seg_mask):
return masked_image, center_, scale_, rles

def get_human_features(self, image, seg_mask, bbox, score, frame_name, cls_id, t_, measurments, gt=1, ann=None):

NPEOPLE = len(score)
BS = NPEOPLE

img_height, img_width, new_image_size, left, top = measurments
ratio = 1.0/int(new_image_size)*self.cfg.render.res
masked_image, center_, scale_, rles = self.get_croped_image(image, bbox, seg_mask)

masked_image_list = []
center_list = []
scale_list = []
rles_list = []
for p_ in range(NPEOPLE):
masked_image, center_, scale_, rles = self.get_croped_image(image, bbox[p_], seg_mask[p_])
masked_image_list.append(masked_image)
center_list.append(center_)
scale_list.append(scale_)
rles_list.append(rles)

masked_image_list = torch.stack(masked_image_list, dim=0)

with torch.no_grad():
extra_args = {}
hmar_out = self.HMAR(masked_image.unsqueeze(0).cuda(), **extra_args)
hmar_out = self.HMAR(masked_image_list.cuda(), **extra_args)
uv_vector = hmar_out['uv_vector']
appe_embedding = self.HMAR.autoencoder_hmar(uv_vector, en=True)
appe_embedding = appe_embedding.view(1, -1)
appe_embedding = appe_embedding.view(appe_embedding.shape[0], -1)
pred_smpl_params, pred_joints_2d, pred_joints, pred_cam = self.HMAR.get_3d_parameters(hmar_out['pose_smpl'], hmar_out['pred_cam'],
center=(center_ + [left, top])*ratio,
center=(np.array(center_list) + np.array([left, top]))*ratio,
img_size=self.cfg.render.res,
scale=np.reshape(np.array([max(scale_)]), (1, 1))*ratio)

pred_smpl_params = {k:v[0].cpu().numpy() for k,v in pred_smpl_params.items()}
scale=np.max(np.array(scale_list), axis=1, keepdims=True)*ratio)
pred_smpl_params = [{k:v[i].cpu().numpy() for k,v in pred_smpl_params.items()} for i in range(BS)]

if(self.cfg.phalp.pose_distance=="joints"):
pose_embedding = pred_joints[0].cpu().view(1, -1)
pose_embedding = pred_joints.cpu().view(BS, -1)
elif(self.cfg.phalp.pose_distance=="smpl"):
pose_embedding = smpl_to_pose_camera_vector(pred_smpl_params, pred_cam)
pose_embedding = torch.from_numpy(pose_embedding)
pose_embedding = []
for i in range(BS):
pose_embedding_ = smpl_to_pose_camera_vector(pred_smpl_params[i], pred_cam[i])
pose_embedding.append(torch.from_numpy(pose_embedding_[0]))
pose_embedding = torch.stack(pose_embedding, dim=0)
else:
raise ValueError("Unknown pose distance")

pred_joints_2d_ = pred_joints_2d.reshape(-1,)/self.cfg.render.res
pred_cam_ = pred_cam.view(-1,)
pred_joints_2d_ = pred_joints_2d.reshape(BS,-1)/self.cfg.render.res
pred_cam_ = pred_cam.view(BS, -1)
pred_joints_2d_.contiguous()
pred_cam_.contiguous()
loca_embedding = torch.cat((pred_joints_2d_, pred_cam_, pred_cam_, pred_cam_), 0)
loca_embedding = torch.cat((pred_joints_2d_, pred_cam_, pred_cam_, pred_cam_), 1)

# keeping it here for legacy reasons (T3DP), but it is not used.
full_embedding = torch.cat((appe_embedding[0].cpu(), pose_embedding[0], loca_embedding.cpu()), 0)

detection_data = {
"bbox" : np.array([bbox[0], bbox[1], (bbox[2] - bbox[0]), (bbox[3] - bbox[1])]),
"mask" : rles,
"conf" : score,

"appe" : appe_embedding[0].cpu().numpy(),
"pose" : pose_embedding[0].numpy(),
"loca" : loca_embedding.cpu().numpy(),
"uv" : uv_vector[0].cpu().numpy(),

"embedding" : full_embedding,
"center" : center_,
"scale" : scale_,
"smpl" : pred_smpl_params,
"camera" : pred_cam_.cpu().numpy(),
"camera_bbox" : hmar_out['pred_cam'][0].cpu().numpy(),
"3d_joints" : pred_joints[0].cpu().numpy(),
"2d_joints" : pred_joints_2d_.cpu().numpy(),
"size" : [img_height, img_width],
"img_path" : frame_name,
"img_name" : frame_name.split('/')[-1] if isinstance(frame_name, str) else None,
"class_name" : cls_id,
"time" : t_,

"ground_truth" : gt,
"annotations" : ann
}

return detection_data

full_embedding = torch.cat((appe_embedding.cpu(), pose_embedding, loca_embedding.cpu()), 1)

detection_data_list = []
for p_ in range(NPEOPLE):
detection_data = {
"bbox" : np.array([bbox[p_][0], bbox[p_][1], (bbox[p_][2] - bbox[p_][0]), (bbox[p_][3] - bbox[p_][1])]),
"mask" : rles_list[p_],
"conf" : score[p_],

"appe" : appe_embedding[p_].cpu().numpy(),
"pose" : pose_embedding[p_].numpy(),
"loca" : loca_embedding[p_].cpu().numpy(),
"uv" : uv_vector[p_].cpu().numpy(),

"embedding" : full_embedding[p_],
"center" : center_list[p_],
"scale" : scale_list[p_],
"smpl" : pred_smpl_params[p_],
"camera" : pred_cam_[p_].cpu().numpy(),
"camera_bbox" : hmar_out['pred_cam'][p_].cpu().numpy(),
"3d_joints" : pred_joints[p_].cpu().numpy(),
"2d_joints" : pred_joints_2d_[p_].cpu().numpy(),
"size" : [img_height, img_width],
"img_path" : frame_name,
"img_name" : frame_name.split('/')[-1] if isinstance(frame_name, str) else None,
"class_name" : cls_id[p_],
"time" : t_,

"ground_truth" : gt[p_],
"annotations" : ann[p_]
}
detection_data_list.append(Detection(detection_data))

return detection_data_list

def forward_for_tracking(self, vectors, attibute="A", time=1):

if(attibute=="P"):
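In the batched get_human_features, every intermediate tensor carries a leading NPEOPLE dimension and the per-person dictionaries are sliced out at the end, which is why the embedding concatenations move from dim 0 to dim 1. A shape-only sketch of that bookkeeping (joint counts are illustrative, not the repo's exact values):

```python
import torch

BS = 5                                         # NPEOPLE: detections in one frame
pred_joints_2d = torch.randn(BS, 45, 2)        # illustrative joint count
pred_cam = torch.randn(BS, 3)

joints_flat = pred_joints_2d.reshape(BS, -1)   # (BS, 90): one row per person
loca_embedding = torch.cat((joints_flat, pred_cam, pred_cam, pred_cam), dim=1)  # (BS, 99)

# Row p_ is what the unbatched code built for a single person,
# so each per-person detection dict just indexes loca_embedding[p_].
print(loca_embedding.shape, loca_embedding[0].shape)   # torch.Size([5, 99]) torch.Size([99])
```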
