forked from patrikeh/go-deep
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathloss.go
114 lines (97 loc) · 2.46 KB
/
loss.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
package deep
import (
"math"
)
// GetLoss maps a LossType to the Loss implementation that computes it.
//
// LossMeanSquared yields MeanSquared and LossBinaryCrossEntropy yields
// BinaryCrossEntropy. Every other value, including LossCrossEntropy,
// LossNone and values outside the declared constants, yields CrossEntropy.
func GetLoss(loss LossType) Loss {
	switch loss {
	case LossMeanSquared:
		return MeanSquared{}
	case LossBinaryCrossEntropy:
		return BinaryCrossEntropy{}
	default:
		// LossCrossEntropy and all unrecognized values share this result.
		return CrossEntropy{}
	}
}
// LossType identifies one of the loss functions offered by this package.
type LossType int

// Known LossType values. They are numbered consecutively from 0 to 3.
const (
	// LossNone signifies unspecified loss
	LossNone LossType = iota
	// LossCrossEntropy is cross entropy loss
	LossCrossEntropy
	// LossBinaryCrossEntropy is the special case of binary cross entropy loss
	LossBinaryCrossEntropy
	// LossMeanSquared is MSE
	LossMeanSquared
)

// String returns a short label for the loss type: "CE", "BinCE" or "MSE".
// LossNone and any value without a label are reported as "N/A".
func (l LossType) String() string {
	label := "N/A"
	switch l {
	case LossMeanSquared:
		label = "MSE"
	case LossBinaryCrossEntropy:
		label = "BinCE"
	case LossCrossEntropy:
		label = "CE"
	}
	return label
}
// Loss is the contract implemented by every loss function in this package.
type Loss interface {
	// F evaluates the loss for a set of estimates against the matching
	// ideal values and returns it as a single number. Both arguments are
	// indexed as [row][element].
	F(estimate [][]float64, ideal [][]float64) float64

	// Df returns the loss derivative term for one estimate/ideal pair.
	// Whether activation takes part in the result is up to the
	// implementation.
	// NOTE(review): activation is presumably the derivative of the output
	// activation function; confirm against the caller.
	Df(estimate float64, ideal float64, activation float64) float64
}
// CrossEntropy is the cross-entropy (CE) loss.
type CrossEntropy struct{}

// F returns the cross entropy averaged over the rows of estimate:
//
//	-(1/len(estimate)) * sum_i sum_j ideal[i][j] * ln(estimate[i][j])
//
// Terms whose ideal value is exactly zero contribute nothing and are
// skipped, following the convention 0*ln(0) = 0. A zero estimate for a
// non-zero ideal still produces +Inf. ideal must have at least the shape
// of estimate, otherwise indexing panics. An empty estimate gives NaN
// (0/0).
func (l CrossEntropy) F(estimate, ideal [][]float64) float64 {
	var sum float64
	for i := range estimate {
		ce := 0.0
		for j := range estimate[i] {
			// Evaluating 0*math.Log(0) would give 0 * -Inf = NaN and turn
			// the whole loss into NaN, so zero-weight terms are left out.
			if ideal[i][j] == 0 {
				continue
			}
			ce += ideal[i][j] * math.Log(estimate[i][j])
		}
		sum -= ce
	}
	return sum / float64(len(estimate))
}

// Df returns estimate - ideal. The activation argument is accepted to
// satisfy the Loss interface and is not used.
// NOTE(review): this is presumably the combined softmax/cross-entropy
// gradient with respect to the pre-activation; confirm against the caller.
func (l CrossEntropy) Df(estimate, ideal, activation float64) float64 {
	return estimate - ideal
}
// BinaryCrossEntropy is the binary cross-entropy loss.
type BinaryCrossEntropy struct{}

// F returns the binary cross entropy averaged over the rows of estimate.
// Each element contributes
//
//	ideal*ln(estimate+eps) + (1-ideal)*ln(1-estimate+eps)
//
// with eps = 1e-16; the contributions are summed per row, negated, and the
// total is divided by len(estimate). ideal must have at least the shape of
// estimate, otherwise indexing panics.
func (l BinaryCrossEntropy) F(estimate, ideal [][]float64) float64 {
	// Small offset that keeps the logarithm arguments away from exactly 0
	// when an estimate is 0 or 1.
	const epsilon = 1e-16

	total := 0.0
	for r, row := range estimate {
		rowCE := 0.0
		for c, p := range row {
			y := ideal[r][c]
			rowCE += y*math.Log(p+epsilon) + (1.0-y)*math.Log(1.0-p+epsilon)
		}
		total -= rowCE
	}
	return total / float64(len(estimate))
}

// Df returns the difference between estimate and ideal. The activation
// argument is not used.
func (l BinaryCrossEntropy) Df(estimate, ideal, activation float64) float64 {
	delta := estimate - ideal
	return delta
}
// MeanSquared is the mean squared error (MSE) loss.
type MeanSquared struct{}

// F returns the mean of (estimate[i][j] - ideal[i][j])^2 taken over every
// element of estimate.
//
// The divisor is the number of elements actually visited, so rows of
// differing length are averaged correctly and an empty estimate does not
// panic; with no elements the result is NaN (0/0). ideal must have at
// least the shape of estimate, otherwise indexing panics.
func (l MeanSquared) F(estimate, ideal [][]float64) float64 {
	var (
		sum   float64
		count int
	)
	for i, row := range estimate {
		for j, e := range row {
			diff := e - ideal[i][j]
			sum += diff * diff
		}
		count += len(row)
	}
	return sum / float64(count)
}

// Df returns activation * (estimate - ideal).
// NOTE(review): activation is presumably the derivative of the output
// activation function; confirm against the caller.
func (l MeanSquared) Df(estimate, ideal, activation float64) float64 {
	return activation * (estimate - ideal)
}