# Compare RMSprop with momentum vs. Adam
# For the class Data Science: Practical Deep Learning Concepts in Theano and TensorFlow
# https://deeplearningcourses.com/c/data-science-deep-learning-in-theano-tensorflow
# https://www.udemy.com/data-science-deep-learning-in-theano-tensorflow
from __future__ import print_function, division
from builtins import range
# Note: you may need to update your version of future
# sudo pip install -U future

import numpy as np
from sklearn.utils import shuffle
import matplotlib.pyplot as plt

from util import get_normalized_data, error_rate, cost, y2indicator
from mlp import forward, derivative_w2, derivative_w1, derivative_b2, derivative_b1


def main():
    max_iter = 10
    print_period = 10

    X, Y = get_normalized_data()
    reg = 0.01

    Xtrain = X[:-1000,]
    Ytrain = Y[:-1000]
    Xtest = X[-1000:,]
    Ytest = Y[-1000:]
    Ytrain_ind = y2indicator(Ytrain)
    Ytest_ind = y2indicator(Ytest)
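    # note: the last 1000 samples are held out as a test set, and y2indicator
    # (from util in this repo) turns the integer labels into one-hot indicator
    # matrices, which the softmax cost function expects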

    N, D = Xtrain.shape
    batch_sz = 500
    n_batches = N // batch_sz
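    # integer division: any leftover samples beyond the last full batch are
    # simply not used within an epoch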

    M = 300
    K = 10
    W1_0 = np.random.randn(D, M) / np.sqrt(D)
    b1_0 = np.zeros(M)
    W2_0 = np.random.randn(M, K) / np.sqrt(M)
    b2_0 = np.zeros(K)

    W1 = W1_0.copy()
    b1 = b1_0.copy()
    W2 = W2_0.copy()
    b2 = b2_0.copy()
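    # the initial values (W1_0, b1_0, W2_0, b2_0) are kept around so the
    # RMSprop run below can restart from the exact same weights, making the
    # comparison between the two optimizers fair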

    # 1st moment
    mW1 = 0
    mb1 = 0
    mW2 = 0
    mb2 = 0

    # 2nd moment
    vW1 = 0
    vb1 = 0
    vW2 = 0
    vb2 = 0

    # hyperparams
    lr0 = 0.001
    beta1 = 0.9
    beta2 = 0.999
    eps = 1e-8
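
    # Adam update (Kingma & Ba, 2015), applied per parameter below:
    #   m_t = beta1 * m_{t-1} + (1 - beta1) * g_t        (1st moment, mean)
    #   v_t = beta2 * v_{t-1} + (1 - beta2) * g_t^2      (2nd moment, uncentered variance)
    #   m_hat = m_t / (1 - beta1^t),  v_hat = v_t / (1 - beta2^t)
    #   theta_t = theta_{t-1} - lr * m_hat / (sqrt(v_hat) + eps)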

    # 1. Adam
    loss_adam = []
    err_adam = []
    t = 1  # step counter starts at 1 so the bias corrections below are well-defined
    for i in range(max_iter):
        for j in range(n_batches):
            Xbatch = Xtrain[j*batch_sz:(j*batch_sz + batch_sz),]
            Ybatch = Ytrain_ind[j*batch_sz:(j*batch_sz + batch_sz),]
            pYbatch, Z = forward(Xbatch, W1, b1, W2, b2)

            # gradients (with L2 regularization)
            gW2 = derivative_w2(Z, Ybatch, pYbatch) + reg*W2
            gb2 = derivative_b2(Ybatch, pYbatch) + reg*b2
            gW1 = derivative_w1(Xbatch, Z, Ybatch, pYbatch, W2) + reg*W1
            gb1 = derivative_b1(Z, Ybatch, pYbatch, W2) + reg*b1

            # new m
            mW1 = beta1*mW1 + (1 - beta1)*gW1
            mb1 = beta1*mb1 + (1 - beta1)*gb1
            mW2 = beta1*mW2 + (1 - beta1)*gW2
            mb2 = beta1*mb2 + (1 - beta1)*gb2

            # new v
            vW1 = beta2*vW1 + (1 - beta2)*gW1*gW1
            vb1 = beta2*vb1 + (1 - beta2)*gb1*gb1
            vW2 = beta2*vW2 + (1 - beta2)*gW2*gW2
            vb2 = beta2*vb2 + (1 - beta2)*gb2*gb2
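
            # m and v start at 0, so early in training they are biased toward
            # 0; dividing by (1 - beta^t) corrects for this, and the correction
            # fades away as t grows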

            # bias correction
            correction1 = 1 - beta1**t
            hat_mW1 = mW1 / correction1
            hat_mb1 = mb1 / correction1
            hat_mW2 = mW2 / correction1
            hat_mb2 = mb2 / correction1

            correction2 = 1 - beta2**t
            hat_vW1 = vW1 / correction2
            hat_vb1 = vb1 / correction2
            hat_vW2 = vW2 / correction2
            hat_vb2 = vb2 / correction2

            # update t
            t += 1

            # apply updates to the params
            W1 = W1 - lr0 * hat_mW1 / np.sqrt(hat_vW1 + eps)
            b1 = b1 - lr0 * hat_mb1 / np.sqrt(hat_vb1 + eps)
            W2 = W2 - lr0 * hat_mW2 / np.sqrt(hat_vW2 + eps)
            b2 = b2 - lr0 * hat_mb2 / np.sqrt(hat_vb2 + eps)
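
            # note: eps is placed inside the sqrt here, whereas the original
            # paper uses sqrt(v_hat) + eps; both avoid division by zero and
            # should behave almost identically in practice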

            if j % print_period == 0:
                pY, _ = forward(Xtest, W1, b1, W2, b2)
                l = cost(pY, Ytest_ind)
                loss_adam.append(l)
                print("Cost at iteration i=%d, j=%d: %.6f" % (i, j, l))

                err = error_rate(pY, Ytest)
                err_adam.append(err)
                print("Error rate:", err)

    pY, _ = forward(Xtest, W1, b1, W2, b2)
    print("Final error rate:", error_rate(pY, Ytest))


    # 2. RMSprop with momentum
    W1 = W1_0.copy()
    b1 = b1_0.copy()
    W2 = W2_0.copy()
    b2 = b2_0.copy()
    loss_rms = []
    err_rms = []

    lr0 = 0.001  # if you set this too high you'll get NaN!
    mu = 0.9
    decay_rate = 0.999
    eps = 1e-8

    # rmsprop cache
    # (initialized at 1 rather than 0 so the first updates aren't divided by
    # a tiny sqrt(cache), which would make them huge)
    cache_W2 = 1
    cache_b2 = 1
    cache_W1 = 1
    cache_b1 = 1

    # momentum terms
    dW1 = 0
    db1 = 0
    dW2 = 0
    db2 = 0
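
    # RMSprop-with-momentum update used below, per parameter theta:
    #   cache = decay_rate * cache + (1 - decay_rate) * g^2
    #   d = mu * d - (1 - mu) * lr * g / (sqrt(cache) + eps)
    #   theta += d
    # (each step is scaled by a running RMS of the gradient, then smoothed
    # with momentum)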

    for i in range(max_iter):
        for j in range(n_batches):
            Xbatch = Xtrain[j*batch_sz:(j*batch_sz + batch_sz),]
            Ybatch = Ytrain_ind[j*batch_sz:(j*batch_sz + batch_sz),]
            pYbatch, Z = forward(Xbatch, W1, b1, W2, b2)

            # updates
            gW2 = derivative_w2(Z, Ybatch, pYbatch) + reg*W2
            cache_W2 = decay_rate*cache_W2 + (1 - decay_rate)*gW2*gW2
            dW2 = mu*dW2 - (1 - mu)*lr0*gW2 / (np.sqrt(cache_W2) + eps)
            W2 += dW2

            gb2 = derivative_b2(Ybatch, pYbatch) + reg*b2
            cache_b2 = decay_rate*cache_b2 + (1 - decay_rate)*gb2*gb2
            db2 = mu*db2 - (1 - mu)*lr0*gb2 / (np.sqrt(cache_b2) + eps)
            b2 += db2

            gW1 = derivative_w1(Xbatch, Z, Ybatch, pYbatch, W2) + reg*W1
            cache_W1 = decay_rate*cache_W1 + (1 - decay_rate)*gW1*gW1
            dW1 = mu*dW1 - (1 - mu)*lr0*gW1 / (np.sqrt(cache_W1) + eps)
            W1 += dW1

            gb1 = derivative_b1(Z, Ybatch, pYbatch, W2) + reg*b1
            cache_b1 = decay_rate*cache_b1 + (1 - decay_rate)*gb1*gb1
            db1 = mu*db1 - (1 - mu)*lr0*gb1 / (np.sqrt(cache_b1) + eps)
            b1 += db1

            if j % print_period == 0:
                pY, _ = forward(Xtest, W1, b1, W2, b2)
                l = cost(pY, Ytest_ind)
                loss_rms.append(l)
                print("Cost at iteration i=%d, j=%d: %.6f" % (i, j, l))

                err = error_rate(pY, Ytest)
                err_rms.append(err)
                print("Error rate:", err)

    pY, _ = forward(Xtest, W1, b1, W2, b2)
    print("Final error rate:", error_rate(pY, Ytest))
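
    # the error-rate histories (err_adam, err_rms) were collected above and
    # could be plotted the same way to compare test error directly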

    plt.plot(loss_adam, label='adam')
    plt.plot(loss_rms, label='rmsprop')
    plt.legend()
    plt.show()


if __name__ == '__main__':
    main()