Skip to content

Commit

Permalink
Fix again to rank computation
Browse files Browse the repository at this point in the history
  • Loading branch information
Arun Tejasvi Chaganty committed Jun 1, 2020
1 parent 36c1eab commit 14b8ed7
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion riemann/evaluations/mean_rank.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def process_batch(self, batch: GraphDataBatch):
n_neighbors = (train_distances < 2).sum(dim=-1)

for row in (sorted_indices < n_neighbors.unsqueeze(1)):
ranks = (row.nonzero().squeeze() + 1).cpu().to(torch.float32)
ranks = (row.nonzero().squeeze(-1) + 1).cpu().to(torch.float32)
if len(ranks) > 0:
adjusted_ranks = (ranks - torch.arange(len(ranks), dtype=torch.float32))
self.hitsat10 += (ranks <= 10).sum().numpy()
Expand Down

0 comments on commit 14b8ed7

Please sign in to comment.