forked from zhan-xu/RigNet
Showing 3 changed files with 213 additions and 0 deletions.
One of the three changed files is empty; the other two are shown below.
@@ -0,0 +1,104 @@
#-------------------------------------------------------------------------------
# © 2019-2020 Zhan Xu
# Name: skeleton_dataset.py
# Purpose: torch_geometric dataset wrapper for skeleton training and inference
# Licence: GNU General Public License v3
#-------------------------------------------------------------------------------
import os
import torch
import numpy as np
import glob
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.utils import add_self_loops
from utils import binvox_rw


class GraphDataset(InMemoryDataset):
    def __init__(self, root):
        super(GraphDataset, self).__init__(root)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        raw_v_filelist = glob.glob(os.path.join(self.root, '*_v.txt'))
        return raw_v_filelist

    @property
    def processed_file_names(self):
        return '{:s}_skeleton_data.pt'.format(self.root.split('/')[-1])

    def __len__(self):
        return len(self.raw_paths)

    def download(self):
        pass
    def sample_on_bone(self, p_pos, ch_pos):
        # sample points along the bone from parent to child, one every ~0.01 units
        ray = ch_pos - p_pos
        bone_length = np.sqrt(np.sum((p_pos - ch_pos) ** 2))
        num_step = np.round(bone_length / 0.01)
        i_step = np.arange(1, num_step + 1)
        unit_step = ray / (num_step + 1e-30)
        unit_step = np.repeat(unit_step[np.newaxis, :], num_step, axis=0)
        res = p_pos + unit_step * i_step[:, np.newaxis]
        return res

    def inside_check(self, pts, vox):
        # map world-space points into voxel-grid coordinates, then keep only
        # the in-bounds points whose voxel is occupied (i.e. inside the shape)
        vc = (pts - vox.translate) / vox.scale * vox.dims[0]
        vc = np.round(vc).astype(int)
        ind1 = np.logical_and(np.all(vc >= 0, axis=1), np.all(vc < vox.dims[0], axis=1))
        vc = np.clip(vc, 0, vox.dims[0] - 1)
        ind2 = vox.data[vc[:, 0], vc[:, 1], vc[:, 2]]
        ind = np.logical_and(ind1, ind2)
        pts = pts[ind]
        return pts, np.argwhere(ind).squeeze()
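    # Worked example (illustrative, not part of the original file): for a bone
    # of length 0.05, sample_on_bone computes num_step = round(0.05 / 0.01) = 5
    # and returns 5 evenly spaced points running from just past the parent up
    # to the child joint. inside_check then keeps only the samples whose
    # rounded voxel coordinate falls inside, and is set in, the .binvox grid.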
    def process(self):
        data_list = []
        i = 0.0
        for v_filename in self.raw_paths:
            print('preprocessing data complete: {:.4f}%'.format(100 * i / len(self.raw_paths)))
            i += 1.0
            v = np.loadtxt(v_filename)
            m = np.loadtxt(v_filename.replace('_v.txt', '_attn.txt'))
            v = torch.from_numpy(v).float()
            m = torch.from_numpy(m).long()
            tpl_e = np.loadtxt(v_filename.replace('_v.txt', '_tpl_e.txt')).T
            geo_e = np.loadtxt(v_filename.replace('_v.txt', '_geo_e.txt')).T
            tpl_e = torch.from_numpy(tpl_e).long()
            geo_e = torch.from_numpy(geo_e).long()
            tpl_e, _ = add_self_loops(tpl_e, num_nodes=v.size(0))
            geo_e, _ = add_self_loops(geo_e, num_nodes=v.size(0))
            y = np.loadtxt(v_filename.replace('_v.txt', '_j.txt'))
            num_joint = len(y)
            joint_pos = y
            # tile or truncate the joint array so y has exactly one row per vertex
            if len(y) < len(v):
                y = np.tile(y, (round(1.0 * len(v) / len(y) + 0.5), 1))
                y = y[:len(v), :]
            elif len(y) > len(v):
                y = y[:len(v), :]
            y = torch.from_numpy(y).float()

            adj = np.loadtxt(v_filename.replace('_v.txt', '_adj.txt'), dtype=np.uint8)

            vox_file = v_filename.replace('_v.txt', '.binvox')
            with open(vox_file, 'rb') as fvox:
                vox = binvox_rw.read_as_3d_array(fvox)
            # build one feature row per joint pair: the two joint ids, their
            # distance, the fraction of bone samples lying inside the volume,
            # and the ground-truth adjacency label
            pair_all = []
            for joint1_id in range(adj.shape[0]):
                for joint2_id in range(joint1_id + 1, adj.shape[1]):
                    dist = np.linalg.norm(joint_pos[joint1_id] - joint_pos[joint2_id])
                    bone_samples = self.sample_on_bone(joint_pos[joint1_id], joint_pos[joint2_id])
                    bone_samples_inside, _ = self.inside_check(bone_samples, vox)
                    outside_proportion = len(bone_samples_inside) / (len(bone_samples) + 1e-10)
                    pair = np.array([joint1_id, joint2_id, dist, outside_proportion, adj[joint1_id, joint2_id]])
                    pair_all.append(pair)
            pair_all = np.array(pair_all)
            pair_all = torch.from_numpy(pair_all).float()
            num_pair = len(pair_all)

            name = int(v_filename.split('/')[-1].split('_')[0])
            data_list.append(Data(x=v[:, 3:6], pos=v[:, 0:3], name=name, mask=m, y=y, num_joint=num_joint,
                                  tpl_edge_index=tpl_e, geo_edge_index=geo_e, pairs=pair_all, num_pair=num_pair))
        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])
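A minimal usage sketch for the class above, assuming a directory that already holds the *_v.txt, *_attn.txt, *_tpl_e.txt, *_geo_e.txt, *_j.txt, *_adj.txt, and .binvox files that process expects; the directory path and batch size here are hypothetical. The first instantiation runs process and caches the result; later runs load the cached .pt file. Since raw_file_names returns full paths from glob, an absolute root is safest.

import os
from torch_geometric.data import DataLoader  # lives in torch_geometric.loader in newer releases

dataset = GraphDataset(root=os.path.abspath('data/train'))  # hypothetical directory
loader = DataLoader(dataset, batch_size=2, shuffle=True)
for batch in loader:
    # pos holds vertex positions, x the per-vertex features from columns 3:6,
    # pairs the per-joint-pair rows built in process above
    print(batch.num_graphs, batch.pos.shape, batch.pairs.shape)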
@@ -0,0 +1,109 @@
#-------------------------------------------------------------------------------
# © 2019-2020 Zhan Xu
# Name: skin_dataset.py
# Purpose: torch_geometric dataset wrapper for skinning training and inference
# Licence: GNU General Public License v3
#-------------------------------------------------------------------------------
import os
import torch
import numpy as np
import glob
from torch_geometric.data import Data, InMemoryDataset
from torch_geometric.utils import add_self_loops


class SkinDataset(InMemoryDataset):
    def __init__(self, root):
        super(SkinDataset, self).__init__(root)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        raw_v_filelist = glob.glob(os.path.join(self.root, '*_v.txt'))
        return raw_v_filelist

    @property
    def processed_file_names(self):
        return '{:s}_skinning_data.pt'.format(self.root.split('/')[-1])

    def __len__(self):
        return len(self.raw_paths)

    def download(self):
        pass
    def load_skin(self, filename):
        # parse a *_skin.txt file: "bones" lines carry per-bone features,
        # "bind" lines carry each vertex's nearest-bone records, and
        # "influence" lines carry the ground-truth skinning weights
        with open(filename, 'r') as fin:
            lines = fin.readlines()
        bones = []
        input = []
        label = []
        nearest_bone_ids = []
        loss_mask_all = []
        for li in lines:
            words = li.strip().split()
            if words[0] == 'bones':
                bones.append([float(w) for w in words[3:]])
            elif words[0] == 'bind':
                words = [float(w) for w in words[1:]]
                sample_input = []
                sample_nearest_bone_ids = []
                loss_mask = []
                for i in range(self.num_nearest_bone):
                    if int(words[3 * i + 1]) == -1:
                        # workaround: this neighbor slot is invalid, so reuse the
                        # first slot's record and mask this slot out of the loss
                        # (words[3] may itself also be invalid)
                        sample_nearest_bone_ids.append(int(words[1]))
                        sample_input += bones[int(words[1])]
                        sample_input.append(words[2])
                        sample_input.append(int(words[3]))
                        loss_mask.append(0)
                    else:
                        sample_nearest_bone_ids.append(int(words[3 * i + 1]))
                        sample_input += bones[int(words[3 * i + 1])]
                        sample_input.append(words[3 * i + 2])
                        sample_input.append(int(words[3 * i + 3]))
                        loss_mask.append(1)
                input.append(np.array(sample_input)[np.newaxis, :])
                nearest_bone_ids.append(np.array(sample_nearest_bone_ids)[np.newaxis, :])
                loss_mask_all.append(np.array(loss_mask)[np.newaxis, :])
            elif words[0] == 'influence':
                sample_label = np.array([float(w) for w in words[1:]])[np.newaxis, :]
                label.append(sample_label)

        input = np.concatenate(input, axis=0)
        nearest_bone_ids = np.concatenate(nearest_bone_ids, axis=0)
        label = np.concatenate(label, axis=0)
        loss_mask_all = np.concatenate(loss_mask_all, axis=0)

        return input, nearest_bone_ids, label, loss_mask_all
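    # Assumed *_skin.txt layout, inferred from the parser above rather than
    # from a documented spec (field meanings are a best guess):
    #   bones <name> <name> <bone feature floats ...>
    #   bind <vertex_idx> {<bone_idx> <value> <value>} x num_nearest_bone
    #   influence <one skinning weight per nearest-bone slot ...>
    # A bone_idx of -1 marks an invalid neighbor slot, which load_skin fills
    # from the first slot and zeroes in loss_mask.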
    def process(self):
        data_list = []
        self.num_nearest_bone = 5
        i = 0.0
        for v_filename in self.raw_paths:
            print('preprocessing data complete: {:.4f}%'.format(100 * i / len(self.raw_paths)))
            i += 1.0
            name = int(v_filename.split('/')[-1].split('_')[0])
            v = np.loadtxt(v_filename)
            v = torch.from_numpy(v).float()
            tpl_e = np.loadtxt(v_filename.replace('_v.txt', '_tpl_e.txt')).T
            geo_e = np.loadtxt(v_filename.replace('_v.txt', '_geo_e.txt')).T
            tpl_e = torch.from_numpy(tpl_e).long()
            geo_e = torch.from_numpy(geo_e).long()
            tpl_e, _ = add_self_loops(tpl_e, num_nodes=v.size(0))
            geo_e, _ = add_self_loops(geo_e, num_nodes=v.size(0))
            skin_input, skin_nn, skin_label, loss_mask = self.load_skin(v_filename.replace('_v.txt', '_skin.txt'))

            skin_input = torch.from_numpy(skin_input).float()
            skin_label = torch.from_numpy(skin_label).float()
            skin_nn = torch.from_numpy(skin_nn).long()
            loss_mask = torch.from_numpy(loss_mask).long()
            num_skin = len(skin_input)

            data_list.append(Data(x=v[:, 3:6], pos=v[:, 0:3], skin_input=skin_input, skin_label=skin_label,
                                  skin_nn=skin_nn, loss_mask=loss_mask, num_skin=num_skin, name=name,
                                  tpl_edge_index=tpl_e, geo_edge_index=geo_e))
        data, slices = self.collate(data_list)
        torch.save((data, slices), self.processed_paths[0])
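A matching usage sketch for SkinDataset, under the same assumptions as before (hypothetical directory and batch size); here the raw folder must also contain the *_skin.txt files that load_skin parses.

import os
from torch_geometric.data import DataLoader  # lives in torch_geometric.loader in newer releases

dataset = SkinDataset(root=os.path.abspath('data/train'))  # hypothetical directory
loader = DataLoader(dataset, batch_size=2, shuffle=True)
for batch in loader:
    # skin_input packs, per vertex, the features of its num_nearest_bone bones;
    # loss_mask flags which of those neighbor slots were valid in the raw file
    print(batch.num_graphs, batch.skin_input.shape, batch.loss_mask.shape)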