
Commit

wrapped all the randomizers in training.clj
jimpil committed May 17, 2012
1 parent 44f3d8b commit 65b091d
Showing 2 changed files with 54 additions and 13 deletions.
4 changes: 2 additions & 2 deletions src/clojure_encog/examples.clj
@@ -20,10 +20,10 @@
trainer ((make-trainer :resilient-prop) network dataset)]
;;make use of the boolean parameter
(if-not keep-trying?
(train trainer 0.01 300 #_(RequiredImprovementStrategy. 5)) ;;train the network once
(train trainer 0.01 300 [] #_[(RequiredImprovementStrategy. 5)]) ;;train the network once
(loop [t false counter 0 _ nil]
(if t (println "Nailed it after" (str counter) "times!")
(recur (train trainer 0.01 300 #_(RequiredImprovementStrategy. 5)) ;;train the network until it succeeds
(recur (train trainer 0.01 300 [] #_[(RequiredImprovementStrategy. 5)]) ;;train the network until it succeeds
(inc counter) (. network reset)))) )
(do (println "\nNeural Network Results:")
(doseq [pair dataset]
63 changes: 52 additions & 11 deletions src/clojure_encog/training.clj
@@ -18,7 +18,11 @@
(org.encog.ml.svm.training SVMTrain)
(org.encog.ml.svm SVM)
(org.encog.ml MLRegression)
(org.encog.util.simple EncogUtility)
(org.encog.neural.networks BasicNetwork)
(org.encog.mathutil.randomize Randomizer BasicRandomizer ConsistentRandomizer
ConstRandomizer FanInRandomizer Distort GaussianRandomizer
NguyenWidrowRandomizer RangeRandomizer)
(org.encog.util.arrayutil NormalizeArray TemporalWindowArray)))


@@ -43,6 +47,26 @@
:else (throw (IllegalArgumentException. "Unsupported data model!"))
))

(defn make-randomizer
"Constructs a Randomizer object. For the parameterised types this returns a fn that builds the Randomizer from its parameters."
[type]
(condp = type
:basic (BasicRandomizer.) ;random number generator seeded from the current time
:range (fn [min-val max-val] (RangeRandomizer. min-val max-val)) ;range randomizer
:consistent (fn [min-rand max-rand] (ConsistentRandomizer. min-rand max-rand)) ;consistent range randomizer
:constant (fn [constant] (ConstRandomizer. constant))
:distort (fn [factor] (Distort. factor))
:fan-in (fn [boundary sqr-root?] (FanInRandomizer. (unchecked-negate boundary) boundary (if (nil? sqr-root?) false sqr-root?)))
:gaussian (fn [mean st-deviation] (GaussianRandomizer. mean st-deviation))
:nguyen-widrow (NguyenWidrowRandomizer.)
))
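;; A minimal usage sketch, not part of this commit. Two calling conventions fall
;; out of make-randomizer: :basic and :nguyen-widrow return a ready Randomizer
;; instance, while the parameterised keys return a fn that must be called with
;; its parameters to build one.
(comment
  (make-randomizer :nguyen-widrow)        ;=> a NguyenWidrowRandomizer instance
  ((make-randomizer :range) -1.0 1.0)     ;=> a RangeRandomizer between -1.0 and 1.0
  ((make-randomizer :gaussian) 0.0 1.0))  ;=> a GaussianRandomizer with mean 0.0, st-dev 1.0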

(defn randomize
"Performs the actual randomization. Expects a randomizer object and some data. Options for data include:
MLMethod -- double -- double[] -- double[][] -- Matrix "
[^Randomizer randomizer data]
(. randomizer randomize data))
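;; A hedged sketch of randomize, not part of this commit; `some-network` is a
;; hypothetical MLMethod. The array/Matrix/MLMethod overloads of Encog's
;; Randomizer work in place, so inspect the argument rather than the return value.
(comment
  (let [r  ((make-randomizer :gaussian) 0.0 1.0)
        xs (double-array 5)]
    (randomize r 0.5)             ;a single double - returns the randomized value
    (randomize r xs)              ;fills xs in place
    (randomize r some-network)))  ;randomizes the weights of an MLMethod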

(defmacro judge
"Convenience macro for implementing the CalculateScore interface, which is needed for genetic and simulated-annealing training."
[minimize? & body]
Expand All @@ -51,7 +75,13 @@
(shouldMinimize [] ~minimize?)))

(defn make-trainer
"Constructs a training-method object."
"Constructs a training-method object given a method. Options include:
-------------------------------------------------------------
:simple :back-prop :quick-prop :manhattan
:genetic :svm :nelder-mead :annealing
:scaled-conjugent :resilient-prop :pnn
-------------------------------------------------------------
Returns an MLTrain object."
[method]
(condp = method
:simple (fn [net tr-set learn-rate] (TrainAdaline. net tr-set (if (nil? learn-rate) 2.0 learn-rate)))
@@ -85,18 +115,28 @@


(defn train
"Does the actual training. This is a potentially lengthy and costly process so most type hints have been provided. Returns true or false depending on whether the error target was met within the iteration limit."
[^MLTrain method ^Double error-tolerance ^Integer limit & strategies] ;;eg: (new RequiredImprovementStrategy 5)
"Does the actual training. This is a potentially lengthy and costly process so most type hints have been provided. Returns true or false depending on whether the error target was met within the iteration limit. This is an overloaded function. It is up to you whether you want to provide limits for error-tolerance, iteration count, or both."
([^MLTrain method ^Double error-tolerance ^Integer limit strategies] ;;eg: (new RequiredImprovementStrategy 5)
(when (seq strategies) (dotimes [i (count strategies)]
(.addStrategy method (nth strategies i))))
(loop [epoch (int 1)]
(if (< limit epoch) false ;;failed to converge
(do (. method iteration)
(println "Epoch #" epoch " Error:" (. method getError))
(if-not (> (. method getError)
error-tolerance) true ;;succeeded to converge
(recur (inc epoch)))))))
(loop [epoch (int 1)]
(if (< limit epoch) false ;;failed to converge
(do (. method iteration)
(println "Epoch #" epoch " Error:" (. method getError))
(if-not (> (. method getError) error-tolerance) true ;;succeeded to converge
(recur (inc epoch)))))))

([^MLTrain method ^Double error-tolerance strategies]
(when (seq strategies) (dotimes [i (count strategies)]
(.addStrategy method (nth strategies i))))
(EncogUtility/trainToError method error-tolerance))

([^MLTrain method strategies] ;;requires only one iteration - SVMs or Nelder-Mead for example
(when (seq strategies) (dotimes [i (count strategies)]
(.addStrategy method (nth strategies i))))
(do (. method iteration)
(println "Error:" (. method getError)))))
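;; A usage sketch of the three train arities, not part of this commit; `network`
;; and `dataset` are assumed to exist, e.g. as built in examples.clj, and the
;; strategies argument is a (possibly empty) collection.
(comment
  (let [trainer ((make-trainer :resilient-prop) network dataset)]
    (train trainer 0.01 300 [])  ;iterate until the error drops to 0.01 or 300 epochs pass
    (train trainer 0.01 [])      ;no epoch cap - delegates to EncogUtility/trainToError
    (train trainer [])))         ;a single iteration - e.g. for SVMs or Nelder-Mead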


(defn normalize
"Normalizes a flat seq (vector/list) of doubles between high-end and low-end and returns the normalized double array. Call seq on the result to convert it back to a Clojure seq so it reads nicely at the REPL."
@@ -105,7 +145,8 @@
(do (. norm setNormalizedHigh high-end)
(. norm setNormalizedLow low-end)
(. norm process (double-array data)))))
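;; A hedged sketch only; the arg vector of normalize is collapsed in this diff,
;; so the parameter order [data high-end low-end] shown here is an assumption.
(comment
  (seq (normalize [1 2 3 4 5] 1.0 -1.0)))  ;=> the same data rescaled into [-1.0 1.0]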






