Skip to content
This repository has been archived by the owner on Nov 10, 2023. It is now read-only.

Commit

Permalink
Add constant weight initializer
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelfneumann committed Sep 24, 2021
1 parent fe84326 commit 5efcbbd
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 12 deletions.
25 changes: 25 additions & 0 deletions initwfn/Constant.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,3 +46,28 @@ func (o OnesConfig) Type() Type {
func (o OnesConfig) Create() G.InitWFn {
return G.Ones()
}

// ConstantConfig implements a configuration of a weight initializer
// that initializes all weights to a constant value.
type ConstantConfig struct {
Value float64
}

// Constant returns a new zeroes weight intializer
func NewConstant(value float64) (*InitWFn, error) {
config := ConstantConfig{value}

return newInitWFn(config)
}

// Type returns the type of the weight initializer created using this
// config
func (c ConstantConfig) Type() Type {
return Constant
}

// Create creates the Gorgonia weight initializer from this
// initializer config
func (c ConstantConfig) Create() G.InitWFn {
return G.ValuesOf(c.Value)
}
26 changes: 14 additions & 12 deletions initwfn/InitWFn.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,13 @@ type Type string

// Available InitWFn types
const (
GlorotU Type = "GlorotU"
GlorotN Type = "GlorotN"
HeU Type = "HeU"
HeN Type = "HeN"
Zeroes Type = "Zeroes"
Ones Type = "Ones"
GlorotU Type = "GlorotU"
GlorotN Type = "GlorotN"
HeU Type = "HeU"
HeN Type = "HeN"
Zeroes Type = "Zeroes"
Ones Type = "Ones"
Constant Type = "Constant"
)

// InitWFn wraps Gorgonia InitWFn so that they can be JSON marshalled and
Expand Down Expand Up @@ -57,12 +58,13 @@ func (i *InitWFn) UnmarshalJSON(data []byte) error {
"Type",
"Config",
map[string]reflect.Type{
string(GlorotU): reflect.TypeOf(GlorotUConfig{}),
string(GlorotN): reflect.TypeOf(GlorotNConfig{}),
string(HeU): reflect.TypeOf(HeUConfig{}),
string(HeN): reflect.TypeOf(HeNConfig{}),
string(Zeroes): reflect.TypeOf(ZeroesConfig{}),
string(Ones): reflect.TypeOf(OnesConfig{}),
string(GlorotU): reflect.TypeOf(GlorotUConfig{}),
string(GlorotN): reflect.TypeOf(GlorotNConfig{}),
string(HeU): reflect.TypeOf(HeUConfig{}),
string(HeN): reflect.TypeOf(HeNConfig{}),
string(Zeroes): reflect.TypeOf(ZeroesConfig{}),
string(Ones): reflect.TypeOf(OnesConfig{}),
string(Constant): reflect.TypeOf(ConstantConfig{}),
})
if err != nil {
return err
Expand Down

0 comments on commit 5efcbbd

Please sign in to comment.