# metrics.py (forked from jik876/hifi-gan)

import torch
import numpy as np


def compute_det_curve(target_scores, nontarget_scores):
    """ Returns false rejection rates, false acceptance rates and the corresponding thresholds. """
    n_scores = target_scores.size + nontarget_scores.size
    all_scores = np.concatenate((target_scores, nontarget_scores))
    labels = np.concatenate((np.ones(target_scores.size), np.zeros(nontarget_scores.size)))

    # Sort labels based on scores
    indices = np.argsort(all_scores, kind='mergesort')
    labels = labels[indices]

    # Compute false rejection and false acceptance rates
    tar_trial_sums = np.cumsum(labels)
    nontarget_trial_sums = nontarget_scores.size - (np.arange(1, n_scores + 1) - tar_trial_sums)

    frr = np.concatenate((np.atleast_1d(0), tar_trial_sums / target_scores.size))  # false rejection rates
    far = np.concatenate((np.atleast_1d(1), nontarget_trial_sums / nontarget_scores.size))  # false acceptance rates
    thresholds = np.concatenate((np.atleast_1d(all_scores[indices[0]] - 0.001), all_scores[indices]))  # Thresholds are the sorted scores

    return frr, far, thresholds


def compute_eer(target_scores, nontarget_scores):
    """ Returns equal error rate (EER) and the corresponding threshold. """
    frr, far, thresholds = compute_det_curve(target_scores, nontarget_scores)
    abs_diffs = np.abs(frr - far)
    min_index = np.argmin(abs_diffs)
    eer = np.mean((frr[min_index], far[min_index]))
    return eer, thresholds[min_index]
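

# Example usage (a sketch; `bona_fide_scores` and `spoof_scores` are hypothetical
# NumPy arrays of detection scores where higher means "more likely real"):
#
#   bona_fide_scores = np.array([2.1, 1.7, 0.9, 2.4])
#   spoof_scores = np.array([-0.3, 0.2, 0.8, -1.1])
#   eer, threshold = compute_eer(bona_fide_scores, spoof_scores)
#   print(f"EER: {eer:.3f} at threshold {threshold:.3f}")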


class DiscriminatorMetrics:
    """ Accumulates per-batch discriminator scores and reports EER-based metrics. """

    def __init__(self):
        self._scores_real = []
        self._scores_fake = []

    @property
    def scores_real(self):
        return torch.cat(self._scores_real, dim=0)

    @property
    def scores_fake(self):
        return torch.cat(self._scores_fake, dim=0)

    @property
    def eer(self):
        scores_real = self.scores_real.detach().cpu().numpy()
        scores_fake = self.scores_fake.detach().cpu().numpy()
        eer, threshold = compute_eer(scores_real, scores_fake)
        return eer

    @property
    def accuracy(self):
        return 1.0 - self.eer

    def accumulate(self, disc_real_outputs, disc_fake_outputs):
        """
        Args:
            disc_real_outputs:
                iterable of per-discriminator outputs for real audio,
                each of shape (batch, channels, timesteps)
            disc_fake_outputs:
                iterable of per-discriminator outputs for generated audio,
                each of shape (batch, channels, timesteps)
        """
        scores_real = []
        scores_fake = []

        # Classification scores from each discriminator
        for d_real, d_fake in zip(disc_real_outputs, disc_fake_outputs):
            # mean prediction over all non-batch dimensions (time and channels)
            scores_real.append(torch.mean(d_real.flatten(start_dim=1), dim=-1))
            scores_fake.append(torch.mean(d_fake.flatten(start_dim=1), dim=-1))

        # Stack scores from different discriminators
        scores_real = torch.stack(scores_real, dim=1)  # -> (batch, num_discriminators)
        scores_fake = torch.stack(scores_fake, dim=1)  # -> (batch, num_discriminators)

        # Voting by averaging scores
        scores_real_voted = torch.mean(scores_real, dim=-1)  # -> (batch,)
        scores_fake_voted = torch.mean(scores_fake, dim=-1)  # -> (batch,)

        if scores_real_voted.shape != scores_fake_voted.shape:
            raise ValueError("Real and generated batch sizes must match")

        self._scores_real.append(scores_real_voted)
        self._scores_fake.append(scores_fake_voted)
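

# Example usage (a sketch; `discriminator`, `real_audio`, `fake_audio` and
# `validation_batches` are hypothetical names, with the discriminator assumed to
# return one output tensor per sub-discriminator, as in a HiFi-GAN-style
# multi-scale/multi-period setup):
#
#   metrics = DiscriminatorMetrics()
#   for real_audio, fake_audio in validation_batches:
#       outputs_real = discriminator(real_audio)
#       outputs_fake = discriminator(fake_audio)
#       metrics.accumulate(outputs_real, outputs_fake)
#   print(f"Discriminator EER: {metrics.eer:.3f}, accuracy: {metrics.accuracy:.3f}")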