Skip to content
This repository has been archived by the owner on Nov 10, 2023. It is now read-only.

Commit

Permalink
Add weight initialization functions
Browse files Browse the repository at this point in the history
Added the following weight initialization functions:
* Zeroes
* Ones
* HeU
* HeN
  • Loading branch information
samuelfneumann committed Sep 24, 2021
1 parent 5b8842c commit d8fb213
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 25 deletions.
48 changes: 48 additions & 0 deletions initwfn/Constant.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
package initwfn

import G "gorgonia.org/gorgonia"

// ZeroesConfig implements a configuration of a zero weight initializer
type ZeroesConfig struct{}

// Zeroes returns a new zeroes weight intializer
func NewZeroes() (*InitWFn, error) {
config := ZeroesConfig{}

return newInitWFn(config)
}

// Type returns the type of the weight initializer created using this
// config
func (z ZeroesConfig) Type() Type {
return Zeroes
}

// Create creates the Gorgonia weight initializer from this
// initializer config
func (z ZeroesConfig) Create() G.InitWFn {
return G.Zeroes()
}

// OnesConfig implements a configuration of a weight initializer that
// initializes all weights to 1.
type OnesConfig struct{}

// Ones returns a new zeroes weight intializer
func NewOnes() (*InitWFn, error) {
config := OnesConfig{}

return newInitWFn(config)
}

// Type returns the type of the weight initializer created using this
// config
func (o OnesConfig) Type() Type {
return Ones
}

// Create creates the Gorgonia weight initializer from this
// initializer config
func (o OnesConfig) Create() G.InitWFn {
return G.Ones()
}
57 changes: 57 additions & 0 deletions initwfn/He.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
package initwfn

import G "gorgonia.org/gorgonia"

// HeUConfig implements a configuration of the He uniform
// initialization algorithm.
type HeUConfig struct {
Gain float64
}

// NewHeU returns a new He Uniform weight initializer
func NewHeU(gain float64) (*InitWFn, error) {
config := HeUConfig{
Gain: gain,
}

return newInitWFn(config)
}

// Type returns the type of initialization algorithm described by
// the configuration.
func (h HeUConfig) Type() Type {
return HeU
}

// Create returns the weight initialization algorithm as a Gorgonia
// InitWFn
func (h HeUConfig) Create() G.InitWFn {
return G.HeU(h.Gain)
}

// HeNConfig implements a configuration of the He normal
// initialization algorithm.
type HeNConfig struct {
Gain float64
}

// NewHeN returns a new He Nniform weight initializer
func NewHeN(gain float64) (*InitWFn, error) {
config := HeNConfig{
Gain: gain,
}

return newInitWFn(config)
}

// Type returns the type of initialization algorithm described by
// the configuration.
func (h HeNConfig) Type() Type {
return HeN
}

// Create returns the weight initialization algorithm as a Gorgonia
// InitWFn
func (h HeNConfig) Create() G.InitWFn {
return G.HeN(h.Gain)
}
7 changes: 7 additions & 0 deletions initwfn/InitWFn.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,10 @@ type Type string
const (
GlorotU Type = "GlorotU"
GlorotN Type = "GlorotN"
HeU Type = "HeU"
HeN Type = "HeN"
Zeroes Type = "Zeroes"
Ones Type = "Ones"
)

// InitWFn wraps Gorgonia InitWFn so that they can be JSON marshalled and
Expand Down Expand Up @@ -56,6 +59,10 @@ func (i *InitWFn) UnmarshalJSON(data []byte) error {
map[string]reflect.Type{
string(GlorotU): reflect.TypeOf(GlorotUConfig{}),
string(GlorotN): reflect.TypeOf(GlorotNConfig{}),
string(HeU): reflect.TypeOf(HeUConfig{}),
string(HeN): reflect.TypeOf(HeNConfig{}),
string(Zeroes): reflect.TypeOf(ZeroesConfig{}),
string(Ones): reflect.TypeOf(OnesConfig{}),
})
if err != nil {
return err
Expand Down
25 changes: 0 additions & 25 deletions initwfn/Zeroes.go

This file was deleted.

0 comments on commit d8fb213

Please sign in to comment.