Skip to content
This repository has been archived by the owner on Nov 10, 2023. It is now read-only.

Commit

Permalink
Fix activations with illegal input values
Browse files Browse the repository at this point in the history
Some activations have a range that is not the real numbers. These
activations would fail upon running. Now, the input nodes are first
transformed (using an absolute value) before taking the activation so
that the input values are in the valid range of the activation function.
  • Loading branch information
samuelfneumann committed Sep 25, 2021
1 parent 3babb20 commit 8dc7e7f
Showing 1 changed file with 43 additions and 21 deletions.
64 changes: 43 additions & 21 deletions network/Activations.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,17 @@ import (
type activationType string

const (
relu activationType = "relu"
softplus activationType = "softplu"
identity activationType = "identity"
tanh activationType = "tanh"
logarithm activationType = "log"
sigmoid activationType = "sigmoid"
sin activationType = "sin"
cos activationType = "cos"
sqrt activationType = "sqrt"
mish activationType = "mish"
nil_ activationType = "nil"
relu activationType = "relu"
softplus activationType = "softplu"
identity activationType = "identity"
tanh activationType = "tanh"
log1p activationType = "log1p"
sigmoid activationType = "sigmoid"
sin activationType = "sin"
cos activationType = "cos"
sqrt activationType = "sqrt"
mish activationType = "mish"
nil_ activationType = "nil"
)

// Activation represents an activation function type
Expand Down Expand Up @@ -85,8 +85,8 @@ func (a *Activation) UnmarshalJSON(data []byte) error {
case mish:
*a = *Mish()

case logarithm:
*a = *Log()
case log1p:
*a = *Log1p()

case softplus:
*a = *Softplus()
Expand Down Expand Up @@ -130,8 +130,8 @@ func (a *Activation) GobDecode(encoded []byte) error {
case mish:
*a = *Mish()

case logarithm:
*a = *Log()
case log1p:
*a = *Log1p()

case softplus:
*a = *Softplus()
Expand Down Expand Up @@ -200,11 +200,22 @@ func Cos() *Activation {
}
}

// Sqrt returns a sqare root activation
// Sqrt returns a sqare root activation. Nodes are first passed through
// an absolute value activation before taking the square root.
func Sqrt() *Activation {
sqrtFn := func(n *G.Node) (*G.Node, error) {
var err error

if n, err = G.Abs(n); err != nil {
return nil, fmt.Errorf("sqrt: could not compute absolute "+
"value %v", err)
}
return G.Sqrt(n)
}

return &Activation{
activationType: sqrt,
f: G.Sqrt,
f: sqrtFn,
}
}

Expand All @@ -216,11 +227,22 @@ func Mish() *Activation {
}
}

// Log returns a logarithm activation
func Log() *Activation {
// Log1p returns a log(|x| + 1) activation. The input node is first
// passed through an absolute value activation, then a 1 is added to
// each element of the input node. Finally the log of the result is
// taken.
func Log1p() *Activation {
logFn := func(n *G.Node) (*G.Node, error) {
if n, err := G.Abs(n); err != nil {
return nil, fmt.Errorf("log1p: could not compute absolute "+
"value: %v", err)
}
return G.Log1p(n)
}

return &Activation{
activationType: logarithm,
f: G.Log,
activationType: log1p,
f: logFn,
}
}

Expand Down

0 comments on commit 8dc7e7f

Please sign in to comment.