Skip to content
This repository has been archived by the owner on Nov 10, 2023. It is now read-only.

Commit

Permalink
Add uniform and Gaussian weight initializers
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelfneumann committed Sep 24, 2021
1 parent 5efcbbd commit 8c6ec31
Show file tree
Hide file tree
Showing 3 changed files with 66 additions and 0 deletions.
31 changes: 31 additions & 0 deletions initwfn/Gaussian.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package initwfn

import G "gorgonia.org/gorgonia"

// Gaussian implements a configuration of a weight initializer that
// draws weights from a gaussian distribution
type GaussianConfig struct {
Mean, StdDev float64
}

// NewGaussian returns a new gaussian weight initializer
func NewGaussian(mean, stddev float64) (*InitWFn, error) {
config := GaussianConfig{
Mean: mean,
StdDev: stddev,
}

return newInitWFn(config)
}

// Type returns the type of initialization algorithm described by
// the configuration.
func (u GaussianConfig) Type() Type {
return Gaussian
}

// Create returns the weight initialization algorithm as a Gorgonia
// InitWFn
func (u GaussianConfig) Create() G.InitWFn {
return G.Gaussian(u.Mean, u.StdDev)
}
4 changes: 4 additions & 0 deletions initwfn/InitWFn.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@ const (
Zeroes Type = "Zeroes"
Ones Type = "Ones"
Constant Type = "Constant"
Uniform Type = "Uniform"
Gaussian Type = "Gaussian"
)

// InitWFn wraps Gorgonia InitWFn so that they can be JSON marshalled and
Expand Down Expand Up @@ -65,6 +67,8 @@ func (i *InitWFn) UnmarshalJSON(data []byte) error {
string(Zeroes): reflect.TypeOf(ZeroesConfig{}),
string(Ones): reflect.TypeOf(OnesConfig{}),
string(Constant): reflect.TypeOf(ConstantConfig{}),
string(Uniform): reflect.TypeOf(UniformConfig{}),
string(Gaussian): reflect.TypeOf(GaussianConfig{}),
})
if err != nil {
return err
Expand Down
31 changes: 31 additions & 0 deletions initwfn/Uniform.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package initwfn

import G "gorgonia.org/gorgonia"

// Uniform implements a configuration of a weight initializer that
// draws weights from a uniform distribution
type UniformConfig struct {
Low, High float64
}

// NewUniform returns a new uniform weight initializer
func NewUniform(low, high float64) (*InitWFn, error) {
config := UniformConfig{
Low: low,
High: high,
}

return newInitWFn(config)
}

// Type returns the type of initialization algorithm described by
// the configuration.
func (u UniformConfig) Type() Type {
return Uniform
}

// Create returns the weight initialization algorithm as a Gorgonia
// InitWFn
func (u UniformConfig) Create() G.InitWFn {
return G.Uniform(u.Low, u.High)
}

0 comments on commit 8c6ec31

Please sign in to comment.