Skip to content

Commit

Permalink
solve filter
Browse files Browse the repository at this point in the history
  • Loading branch information
yichen14 committed Mar 2, 2023
1 parent a599f4f commit 499440a
Showing 1 changed file with 24 additions and 10 deletions.
34 changes: 24 additions & 10 deletions modules/utils/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,29 @@ def sigmoid(x):
tscores = []
fscores = []
test_len = len(edges_pos)

edges_p_all = torch.Tensor()
edges_n_all = torch.Tensor()
for edges_p in edges_pos:
edges_p_all = torch.cat((edges_p_all, edges_p.cpu().detach()))

for edges_n in edges_neg:
edges_n_all = torch.cat((edges_n_all, edges_n.cpu().detach()))

t_src_all, t_dst_all = edges_p_all.T
f_src_all, f_dst_all = edges_n_all.T

node_dict = {}
idx_dict = {}

for idx, node in enumerate(t_src_all):
if int(node) not in node_dict:
node_dict[int(node)] = []
node_dict[int(node)].append(int(t_dst_all[idx]))
if int(node) not in idx_dict:
idx_dict[int(node)] = []
idx_dict[int(node)].append(idx)

for i in range(test_len):
# Predict on test set of edges
emb = embs[i].cpu().detach() # n * d
Expand Down Expand Up @@ -83,16 +106,7 @@ def sigmoid(x):
rank_scores[range(len(t_src)), list(t_src)] = 0

if self.filter_flag:
node_dict = {}
idx_dict = {}

for idx, node in enumerate(t_src):
if int(node) not in node_dict:
node_dict[int(node)] = []
node_dict[int(node)].append(int(t_dst[idx]))
if int(node) not in idx_dict:
idx_dict[int(node)] = []
idx_dict[int(node)].append(idx)


t_one_hot = F.one_hot(t_dst.long(), num_classes=rank_scores.size(1)).to(torch.float32)
mask = torch.ones_like(rank_scores)
Expand Down

0 comments on commit 499440a

Please sign in to comment.