forked from zhan-xu/RigNet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathGCN.py
67 lines (58 loc) · 3.06 KB
/
GCN.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
#-------------------------------------------------------------------------------
# Name: GCN.py
# Purpose: definition of joint prediction module.
# RigNet Copyright 2020 University of Massachusetts
# RigNet is made available under General Public License Version 3 (GPLv3), or under a Commercial License.
# Please see the LICENSE README.txt file in the main directory for more information and instruction on using and licensing RigNet.
#-------------------------------------------------------------------------------
import torch
from models.gcn_basic_modules import MLP, GCU
from torch_scatter import scatter_max, scatter_mean
from torch.nn import Sequential, Dropout, Linear, ReLU, Parameter
class JointPredNet(torch.nn.Module):
def __init__(self, out_channels, input_normal, arch, aggr='max'):
super(JointPredNet, self).__init__()
self.input_normal = input_normal
self.arch = arch
if self.input_normal:
self.input_channel = 6
else:
self.input_channel = 3
self.gcu_1 = GCU(in_channels=self.input_channel, out_channels=64, aggr=aggr)
self.gcu_2 = GCU(in_channels=64, out_channels=256, aggr=aggr)
self.gcu_3 = GCU(in_channels=256, out_channels=512, aggr=aggr)
# feature compression
self.mlp_glb = MLP([(64 + 256 + 512), 1024])
self.mlp_tramsform = Sequential(MLP([1024 + self.input_channel + 64 + 256 +512, 1024, 256]),
Dropout(0.7), Linear(256, out_channels))
def forward(self, data):
if self.input_normal:
x = torch.cat([data.pos, data.x], dim=1)
else:
x = data.pos
geo_edge_index, tpl_edge_index, batch = data.geo_edge_index, data.tpl_edge_index, data.batch
x_1 = self.gcu_1(x, tpl_edge_index, geo_edge_index)
x_2 = self.gcu_2(x_1, tpl_edge_index, geo_edge_index)
x_3 = self.gcu_3(x_2, tpl_edge_index, geo_edge_index)
x_4 = self.mlp_glb(torch.cat([x_1, x_2, x_3], dim=1))
x_global, _ = scatter_max(x_4, data.batch, dim=0)
#x_global_mean = scatter_mean(x_4, data.batch, dim=0)
#x_global = torch.cat([x_global_max, x_global_mean], dim=1)
x_global = torch.repeat_interleave(x_global, torch.bincount(data.batch), dim=0)
x_5 = torch.cat([x_global, x, x_1, x_2, x_3], dim=1)
out = self.mlp_tramsform(x_5)
if self.arch == 'jointnet':
out = torch.tanh(out)
return out
class JOINTNET_MASKNET_MEANSHIFT(torch.nn.Module):
def __init__(self):
super(JOINTNET_MASKNET_MEANSHIFT, self).__init__()
self.jointnet = JointPredNet(3, input_normal=False, arch='jointnet', aggr='max')
self.masknet = JointPredNet(1, input_normal=False, arch='masknet', aggr='max')
self.bandwidth = Parameter(torch.Tensor(1))
self.bandwidth.data.fill_(0.04)
def forward(self, data):
x_offset = self.jointnet(data)
x_mask_prob_0 = self.masknet(data)
x_mask_prob = torch.sigmoid(x_mask_prob_0)
return x_offset, x_mask_prob_0, x_mask_prob, self.bandwidth