-
Notifications
You must be signed in to change notification settings - Fork 4
/
Copy path: uce_utils.py
38 lines (28 loc) · 1.06 KB
/
uce_utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
# Max-Heinrich Laves
# Institute of Mechatronic Systems
# Leibniz Universität Hannover, Germany
# 2019
import torch
# Public API of this module (what `from uce_utils import *` exports).
__all__ = ['accuracy', 'kl_loss', 'nentr', 'xavier_normal_init']
def accuracy(input, target):
    """Compute top-1 classification accuracy.

    :param input: class scores/logits of shape (batch, num_classes)
    :param target: ground-truth class indices of shape (batch,)
    :return: fraction of correct predictions as a Python float
    """
    # torch.argmax replaces the deprecated `.data` access of the original;
    # no gradient flows through argmax anyway.
    predictions = torch.argmax(input, dim=1)
    # Mean of the boolean match mask == correct / batch_size.
    return (predictions == target).float().mean().item()
def kl_loss(logits):
    """Negative mean of the log-softmax of ``logits`` along dim 1.

    :param logits: unnormalised class scores, shape (batch, num_classes)
    :return: scalar tensor
    """
    log_probs = torch.nn.functional.log_softmax(logits, dim=1)
    return log_probs.mean().neg()
def nentr(p, base=None):
    """
    Calculates entropy of p to the base b. If base is None, the natural logarithm is used.
    :param p: batches of class label probability distributions (softmax output)
    :param base: base b
    :return: per-sample entropies, shape (batch,)
    """
    # Small offset keeps log() finite when a probability is exactly zero.
    eps = torch.tensor([1e-16], device=p.device)
    logs = (p + eps).log()
    if base:
        # Change of base: log_b(x) = ln(x) / ln(b).
        b = torch.tensor([base], device=p.device, dtype=torch.float32)
        logs = logs / b.log()
    # sum(p * log p) is <= 0; abs() flips it to the usual positive convention.
    return (p * logs).sum(dim=1).abs()
def xavier_normal_init(m):
    """Xavier-normal initialisation for Conv2d weights.

    Intended to be passed to ``torch.nn.Module.apply``; any module that is
    not a ``torch.nn.Conv2d`` is left untouched.
    """
    if not isinstance(m, torch.nn.Conv2d):
        return
    torch.nn.init.xavier_normal_(m.weight.data)