Skip to content
This repository has been archived by the owner on Nov 10, 2023. It is now read-only.

Commit

Permalink
Add RMSProp solver
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelfneumann committed Sep 24, 2021
1 parent 3161ba8 commit fe84326
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 2 deletions.
3 changes: 1 addition & 2 deletions solver/AdamSolver.go → solver/Adam.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,7 @@ func NewDefaultAdam(stepSize float64, batchSize int) (*Solver, error) {

// NewAdam returns a new Adam Solver
func NewAdam(stepSize, epsilon, beta1, beta2 float64, batchSize int,
clip float64) (*Solver,
error) {
clip float64) (*Solver, error) {
adam := AdamConfig{
StepSize: stepSize,
Epsilon: epsilon,
Expand Down
76 changes: 76 additions & 0 deletions solver/RMSProp.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
package solver

import (
"fmt"

G "gorgonia.org/gorgonia"
)

// RMSProprConfig implements a specific configuration of the RMSProp
// solver
type RMSPropConfig struct {
StepSize float64
Epsilon float64
Eta float64 // Only default value of 0.001 supported by Gorgonia
Rho float64
Batch int
Clip float64 // <= 0 if no clipping
}

// NewDefaultRMSProp returns a new RMSProp Solver with default
// hyperparameters
func NewDefaultRMSProp(stepSize float64, batchSize int) (*Solver, error) {
return NewRMSProp(stepSize, 1e-8, 0.001, 0.999, batchSize, -1.0)
}

// NewRMSProp returns a new RMSProp Solver
func NewRMSProp(stepSize, epsilon, eta, rho float64, batchSize int,
clip float64) (*Solver, error) {
if eta != 0.001 {
return nil, fmt.Errorf("newRMSProp: only the default value of " +
"η = 0.001 is currently supported")
}

rmsprop := RMSPropConfig{
StepSize: stepSize,
Epsilon: epsilon,
Eta: eta,
Rho: rho,
Batch: int(batchSize),
Clip: clip,
}

return newSolver(RMSProp, rmsprop)
}

// Create returns a new Gorgonia RMSProp Solver as described by the
// RMSPropConfig
func (r RMSPropConfig) Create() G.Solver {
var solver G.Solver

if r.Clip <= 0 {
solver = G.NewRMSPropSolver(
G.WithLearnRate(r.StepSize),
G.WithEps(r.Epsilon),
// G.WithEta(r.Eta), // Unsupported by Gorgonia
G.WithRho(r.Rho),
G.WithBatchSize(float64(r.Batch)),
)
} else {
solver = G.NewAdamSolver(
G.WithLearnRate(r.StepSize),
G.WithEps(r.Epsilon),
// G.WithEta(r.Eta), // Unsupported by Gorgonia
G.WithRho(r.Rho),
G.WithBatchSize(float64(r.Batch)),
G.WithClip(r.Clip),
)
}
return solver
}

// ValidType returns if the given Solver type is a valid type to be
// created with this config.
func (r RMSPropConfig) ValidType(t Type) bool {
return t == RMSProp
}
2 changes: 2 additions & 0 deletions solver/Solver.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ type Type string
// Available solver types
const (
Adam Type = "Adam"
RMSProp Type = "RMSProp"
Vanilla Type = "Vanilla"
)

Expand Down Expand Up @@ -53,6 +54,7 @@ func (s *Solver) UnmarshalJSON(data []byte) error {
map[string]reflect.Type{
string(Vanilla): reflect.TypeOf(VanillaConfig{}),
string(Adam): reflect.TypeOf(AdamConfig{}),
string(RMSProp): reflect.TypeOf(RMSPropConfig{}),
})
if err != nil {
return err
Expand Down
File renamed without changes.

0 comments on commit fe84326

Please sign in to comment.