-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path topological_attrib_t.py
executable file
·101 lines (73 loc) · 3.01 KB
/
topological_attrib_t.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import torch
import torch.nn as nn
import copy
import torch.nn.functional as F
import torch.autograd as autograd
def get_attrib(net, graph, features, labels):
    """Compute the topological (edge) attributions of the teacher GNN.

    The attribution is the gradient, with respect to the auxiliary edge
    feature ``net.g.edata['e_grad']``, of the sum of the output activations
    selected by the strictly positive entries of ``labels``.

    Args:
        net (nn.Module): the teacher GNN model. ``net.g.edata['e_grad']`` must
            be a tensor that requires grad and takes part in ``net``'s forward
            pass (otherwise ``autograd.grad`` raises).
        graph (DGLGraph): the input graph. Not read here -- the graph already
            bound to ``net.g`` is used; kept for interface compatibility.
        features (torch.Tensor): the input node features; cast with
            ``.float()`` before the forward pass.
        labels (torch.Tensor): the soft labels. Entries ``> 0`` select the
            output activations that receive an upstream gradient of one; the
            resulting boolean mask indexes the model output, so it is
            presumably the same shape as that output -- confirm at call sites.

    Returns:
        torch.Tensor: the gradient w.r.t. ``net.g.edata['e_grad']`` (same
        shape as that tensor).
    """
    # Binarize the soft labels: strictly positive -> True, everything else
    # (zero, negative, NaN) -> False.
    mask = labels > 0.0
    # autograd.grad does not accumulate into .grad, but clear any stale
    # gradient left on the edge feature by earlier backward() calls.
    if net.g.edata['e_grad'].grad is not None:
        net.g.edata['e_grad'].grad.zero_()
    # generate model outputs
    output = net(features.float())
    # A one-hot-on-mask upstream gradient makes the vector-Jacobian product
    # equal to d(sum(output[mask])) / d(e_grad).
    output_grad = torch.zeros_like(output)
    output_grad[mask] = 1
    (attrib,) = autograd.grad(
        outputs=output,
        inputs=net.g.edata['e_grad'],
        grad_outputs=output_grad,
    )
    return attrib
class ATTNET_t(nn.Module):
    """Wrap a teacher GNN so it can produce topological attribution maps.

    Args:
        model (nn.Module): the teacher GNN. It must expose a ``g`` attribute
            (the graph it runs on) and an iterable ``gat_layers`` whose
            elements also carry a ``g`` attribute.
        args: unused; kept for interface compatibility with callers.
    """

    def __init__(self, model, args):
        super().__init__()
        # set up the network
        self.net = model

    def _bind_graph(self, graph):
        """Point the wrapped network and each of its GAT layers at ``graph``."""
        self.net.g = graph
        for layer in self.net.gat_layers:
            layer.g = graph

    def forward(self, graph, features):
        """Run the wrapped teacher GNN on ``graph``.

        Args:
            graph (DGLGraph): the input graph containing the topological information
            features (torch.Tensor): the input node features

        Returns:
            torch.Tensor: the logits produced by the wrapped model
        """
        self._bind_graph(graph)
        return self.net(features)

    def observe(self, graph, features, labels):
        """Return the topological attribution maps for ``graph``.

        Attaches an all-ones auxiliary edge feature ``edata['e_grad']`` of
        shape ``(num_edges, 1, 1)`` that requires grad. It then returns its
        gradient as computed by ``get_attrib``. The edge feature is allocated
        on ``features.device``, so CPU and any CUDA device both work. It stays
        on ``graph`` after the call.

        The wrapped network is switched to train mode for this pass. It is
        always left in eval mode afterwards, even if the computation raises.

        Args:
            graph (DGLGraph): the input graph containing the topological information
            features (torch.Tensor): the input node features
            labels (torch.Tensor): the soft labels

        Returns:
            torch.Tensor: topological attributions, one per edge (same shape
            as the auxiliary edge feature).
        """
        # NOTE(review): train mode enables any dropout in the wrapped net,
        # which would make the attributions stochastic -- confirm intended.
        self.net.train()
        try:
            self._bind_graph(graph)
            # Auxiliary unary edge features; their gradient is the attribution.
            # Created with requires_grad=True so the stored tensor is the leaf.
            # Assumes the wrapped net reads edata['e_grad'] in its forward
            # pass -- TODO confirm against the GAT layer implementation.
            self.net.g.edata['e_grad'] = torch.ones(
                (self.net.g.number_of_edges(), 1, 1),
                dtype=torch.float32,
                device=features.device,
                requires_grad=True,
            )
            return get_attrib(self.net, graph, features, labels)
        finally:
            self.net.eval()