Revert "try to train but wrong"
This reverts commit 47ff0fa.
cxTAKA committed Dec 21, 2023
1 parent 47ff0fa commit f0ab3fb
Showing 7 changed files with 19 additions and 20 deletions.
4 changes: 2 additions & 2 deletions SocketForUnity.py
@@ -237,8 +237,8 @@ def main():
# save_pkl2json(testfilepath)

################################------ npz data file read/write test
# testfilepath = os.path.join("data_split","CMU","01","01_01_poses.npz")
# save_npz2json(testfilepath)
testfilepath = os.path.join("data_split","CMU","01","01_01_poses.npz")
save_npz2json(testfilepath)
return

if __name__ == '__main__':
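Note for readers: `save_npz2json` is defined elsewhere in `SocketForUnity.py` and its body is not shown in this diff. A minimal sketch of what such an .npz-to-JSON conversion typically looks like (the function body below is an assumption for illustration, not the repository's implementation):

```python
import json
import numpy as np

def save_npz2json_sketch(npz_path, json_path=None):
    """Dump every array in an AMASS-style .npz file to a JSON file (illustrative only)."""
    data = np.load(npz_path)
    # Convert each ndarray to a nested Python list so json can serialize it
    payload = {key: data[key].tolist() for key in data.files}
    json_path = json_path or npz_path.replace(".npz", ".json")
    with open(json_path, "w") as f:
        json.dump(payload, f)
```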
17 changes: 8 additions & 9 deletions data/dataset_amass.py
@@ -53,7 +53,6 @@ def __getitem__(self, idx):
data = pickle.load(f)

if self.opt['phase'] == 'train':
# When the file is shorter than the window size, randomly reload an earlier file
while data['rotation_local_full_gt_list'].shape[0] <self.window_size:
idx = random.randint(0,idx)
filename = self.filename_list[idx]
@@ -67,21 +66,21 @@
head_global_trans_list = data['head_global_trans_list']


if self.opt['phase'] == 'train': # train
# During training, randomly crop a segment
frame = np.random.randint(hmd_position_global_full_gt_list.shape[0] - self.window_size)
input_hmd = hmd_position_global_full_gt_list[frame:frame + self.window_size+1,...].reshape(self.window_size+1, -1)
output_gt = rotation_local_full_gt_list[frame : frame + self.window_size + 1,...]
if self.opt['phase'] == 'train':

frame = np.random.randint(hmd_position_global_full_gt_list.shape[0] - self.window_size + 1 - 1)
input_hmd = hmd_position_global_full_gt_list[frame:frame + self.window_size+1,...].reshape(self.window_size+1, -1).float()
output_gt = rotation_local_full_gt_list[frame + self.window_size - 1 : frame + self.window_size - 1 + 1,...].float()

return {'L': input_hmd.float(),
'H': output_gt.float(),
return {'L': input_hmd,
'H': output_gt,
'P': 1,
'Head_trans_global':head_global_trans_list[frame + self.window_size - 1:frame + self.window_size - 1+1,...],
'pos_pelvis_gt':body_parms_list['trans'][frame + self.window_size - 1:frame + self.window_size - 1+1,...],
'vel_pelvis_gt':body_parms_list['trans'][frame + self.window_size - 1:frame + self.window_size - 1+1,...]-body_parms_list['trans'][frame + self.window_size - 2:frame + self.window_size - 2+1,...]
}

else: # test
else:

input_hmd = hmd_position_global_full_gt_list.reshape(hmd_position_global_full_gt_list.shape[0], -1)[1:]
output_gt = rotation_local_full_gt_list[1:]
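The two train-branch variants above differ mainly in whether the regression target is the whole window or only its last frame. A minimal sketch of the single-frame-target variant, with simplified names and shapes assumed for illustration:

```python
import numpy as np
import torch

def sample_train_window(hmd_seq, rot_seq, window_size):
    """hmd_seq: (T, D_in) HMD signal; rot_seq: (T, D_out) local rotations (shapes assumed)."""
    T = hmd_seq.shape[0]
    frame = np.random.randint(T - window_size)                 # random window start
    input_hmd = hmd_seq[frame:frame + window_size + 1]         # window_size + 1 input frames
    # target is only the last frame inside the window, not the whole window
    output_gt = rot_seq[frame + window_size - 1:frame + window_size]
    return torch.as_tensor(input_hmd).float(), torch.as_tensor(output_gt).float()
```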
2 changes: 1 addition & 1 deletion main_train_avatarposer.py
@@ -103,7 +103,7 @@ def main(json_path='options/train_avatarposer.json'):
shuffle=dataset_opt['dataloader_shuffle'],
num_workers=dataset_opt['dataloader_num_workers'],
drop_last=True,
pin_memory=False)
pin_memory=True)
elif phase == 'test':
test_set = define_Dataset(dataset_opt)
test_loader = DataLoader(test_set, batch_size=dataset_opt['dataloader_batch_size'],
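Side note on the `pin_memory` toggle above: with `pin_memory=True`, PyTorch allocates batches in page-locked host memory, which allows asynchronous host-to-GPU copies when the transfer uses `non_blocking=True`. A minimal sketch, with dummy data and names not taken from this repository:

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(1024, 54), torch.randn(1024, 132))  # dummy shapes
loader = DataLoader(dataset, batch_size=256, shuffle=True,
                    num_workers=0,  # the training config uses more workers; 0 keeps this sketch self-contained
                    drop_last=True, pin_memory=True)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
for inputs, targets in loader:
    # pinned batches can be copied to the GPU asynchronously
    inputs = inputs.to(device, non_blocking=True)
    targets = targets.to(device, non_blocking=True)
    break  # one batch is enough for the illustration
```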
1 change: 0 additions & 1 deletion models/loss.py
@@ -1,4 +1,3 @@
import math
import torch
import torch.nn as nn
import torchvision
4 changes: 2 additions & 2 deletions models/model_avatarposer.py
@@ -149,8 +149,8 @@ def feed_data(self, data, need_H=True, test=False):
self.L = data['L'].to(self.device)
self.P = data['P']
self.Head_trans_global = data['Head_trans_global'].to(self.device)
self.H_global_orientation = data['H'][:,:6].to(self.device)
self.H_joint_rotation = data['H'][:,6:].to(self.device)
self.H_global_orientation = data['H'].squeeze()[:,:6].to(self.device)
self.H_joint_rotation = data['H'].squeeze()[:,6:].to(self.device)
# self.H = torch.cat([self.H_global_orientation, self.H_joint_rotation],dim=-1).to(self.device)
self.H_joint_position = self.netG.module.fk_module(self.H_global_orientation, self.H_joint_rotation , self.bm)

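The `.squeeze()` added above drops the singleton frame dimension carried by the single-frame target, so that the first 6 values can be read as the global (root) orientation in 6D form and the remainder as per-joint rotations. A minimal shape-level sketch (the 22-joint x 6D layout is assumed from the rest of the codebase):

```python
import torch

batch, joints = 8, 22
H = torch.randn(batch, 1, joints * 6)      # (batch, 1 frame, 22 joints x 6D rotation)

H_flat = H.squeeze(1)                      # -> (batch, 132) after dropping the frame dim
global_orientation = H_flat[:, :6]         # root orientation, 6D representation
joint_rotation = H_flat[:, 6:]             # remaining 21 joints x 6D = 126 values
print(global_orientation.shape, joint_rotation.shape)  # torch.Size([8, 6]) torch.Size([8, 126])
```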
4 changes: 2 additions & 2 deletions options/train_avatarposer.json
@@ -18,8 +18,8 @@
, "dataroot": "./data_fps60"// path of training dataset

, "dataloader_shuffle": true
, "dataloader_num_workers": 1
, "dataloader_batch_size": 1 // batch size 1 | 16 | 32 | 48 | 64 | 128 | 256
, "dataloader_num_workers": 16
, "dataloader_batch_size": 256 // batch size 1 | 16 | 32 | 48 | 64 | 128 | 256
, "num_input": 3
, "window_size": 40

7 changes: 4 additions & 3 deletions prepare_data.py
@@ -104,27 +104,28 @@
rotation_global_matrot = local2global_pose(rotation_local_matrot, bm.kintree_table[0].long()) # rotation of joints relative to the origin

head_rotation_global_matrot = rotation_global_matrot[:,[15],:,:]
#

rotation_global_6d = utils_transform.matrot2sixd(rotation_global_matrot.reshape(-1,3,3)).reshape(rotation_global_matrot.shape[0],-1,6)
input_rotation_global_6d = rotation_global_6d[1:,[15,20,21],:]

rotation_velocity_global_matrot = torch.matmul(torch.inverse(rotation_global_matrot[:-1]),rotation_global_matrot[1:])
rotation_velocity_global_6d = utils_transform.matrot2sixd(rotation_velocity_global_matrot.reshape(-1,3,3)).reshape(rotation_velocity_global_matrot.shape[0],-1,6)
input_rotation_velocity_global_6d = rotation_velocity_global_6d[:,[15,20,21],:]

#
position_global_full_gt_world = body_pose_world.Jtr[:,:22,:] # position of joints relative to the world origin

position_head_world = position_global_full_gt_world[:,15,:] # world position of head

head_global_trans = torch.eye(4).repeat(position_head_world.shape[0],1,1)
head_global_trans[:,:3,:3] = head_rotation_global_matrot.squeeze()
head_global_trans[:,:3,3] = position_global_full_gt_world[:,15,:]

# embed()
head_global_trans_list = head_global_trans[1:]




num_frames = position_global_full_gt_world.shape[0]-1


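For reference, `head_global_trans` built above is a standard 4x4 homogeneous transform: the upper-left 3x3 block holds the head's global rotation and the last column its world position. A minimal sketch of applying such a transform to a point given in head-local coordinates (names and the sample point are illustrative):

```python
import torch

def apply_head_transform(head_global_trans, point_local):
    """head_global_trans: (T, 4, 4); point_local: (3,) point in head-local coordinates."""
    T = head_global_trans.shape[0]
    # promote the point to homogeneous coordinates and broadcast over all frames
    p_h = torch.cat([point_local, torch.ones(1)]).expand(T, 4).unsqueeze(-1)  # (T, 4, 1)
    p_world = torch.matmul(head_global_trans, p_h).squeeze(-1)[:, :3]         # (T, 3)
    return p_world

frames = torch.eye(4).repeat(5, 1, 1)  # identity transforms for 5 frames
print(apply_head_transform(frames, torch.tensor([0.0, 0.0, 0.1])).shape)  # torch.Size([5, 3])
```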
