This repository has been archived by the owner on Nov 10, 2023. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add uniform and Gaussian weight initializers
- Loading branch information
1 parent
5efcbbd
commit 8c6ec31
Showing
3 changed files
with
66 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
package initwfn | ||
|
||
import G "gorgonia.org/gorgonia" | ||
|
||
// Gaussian implements a configuration of a weight initializer that | ||
// draws weights from a gaussian distribution | ||
type GaussianConfig struct { | ||
Mean, StdDev float64 | ||
} | ||
|
||
// NewGaussian returns a new gaussian weight initializer | ||
func NewGaussian(mean, stddev float64) (*InitWFn, error) { | ||
config := GaussianConfig{ | ||
Mean: mean, | ||
StdDev: stddev, | ||
} | ||
|
||
return newInitWFn(config) | ||
} | ||
|
||
// Type returns the type of initialization algorithm described by | ||
// the configuration. | ||
func (u GaussianConfig) Type() Type { | ||
return Gaussian | ||
} | ||
|
||
// Create returns the weight initialization algorithm as a Gorgonia | ||
// InitWFn | ||
func (u GaussianConfig) Create() G.InitWFn { | ||
return G.Gaussian(u.Mean, u.StdDev) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
package initwfn | ||
|
||
import G "gorgonia.org/gorgonia" | ||
|
||
// Uniform implements a configuration of a weight initializer that | ||
// draws weights from a uniform distribution | ||
type UniformConfig struct { | ||
Low, High float64 | ||
} | ||
|
||
// NewUniform returns a new uniform weight initializer | ||
func NewUniform(low, high float64) (*InitWFn, error) { | ||
config := UniformConfig{ | ||
Low: low, | ||
High: high, | ||
} | ||
|
||
return newInitWFn(config) | ||
} | ||
|
||
// Type returns the type of initialization algorithm described by | ||
// the configuration. | ||
func (u UniformConfig) Type() Type { | ||
return Uniform | ||
} | ||
|
||
// Create returns the weight initialization algorithm as a Gorgonia | ||
// InitWFn | ||
func (u UniformConfig) Create() G.InitWFn { | ||
return G.Uniform(u.Low, u.High) | ||
} |