Commit 0cb1607

add adam example
1 parent 945295e commit 0cb1607

1 file changed: +204 -0 lines changed

ann_class2/adam.py

Lines changed: 204 additions & 0 deletions
@@ -0,0 +1,204 @@
# Compare RMSprop with momentum vs. Adam
# For the class Data Science: Practical Deep Learning Concepts in Theano and TensorFlow
# https://deeplearningcourses.com/c/data-science-deep-learning-in-theano-tensorflow
# https://www.udemy.com/data-science-deep-learning-in-theano-tensorflow
from __future__ import print_function, division
from builtins import range
# Note: you may need to update your version of future
# sudo pip install -U future

import numpy as np
from sklearn.utils import shuffle
import matplotlib.pyplot as plt

from util import get_normalized_data, error_rate, cost, y2indicator
from mlp import forward, derivative_w2, derivative_w1, derivative_b2, derivative_b1


def main():
    max_iter = 10
    print_period = 10

    X, Y = get_normalized_data()
    reg = 0.01

    Xtrain = X[:-1000,]
    Ytrain = Y[:-1000]
    Xtest = X[-1000:,]
    Ytest = Y[-1000:]
    Ytrain_ind = y2indicator(Ytrain)
    Ytest_ind = y2indicator(Ytest)

    N, D = Xtrain.shape
    batch_sz = 500
    n_batches = N // batch_sz

    M = 300
    K = 10
    W1_0 = np.random.randn(D, M) / np.sqrt(D)
    b1_0 = np.zeros(M)
    W2_0 = np.random.randn(M, K) / np.sqrt(M)
    b2_0 = np.zeros(K)

    W1 = W1_0.copy()
    b1 = b1_0.copy()
    W2 = W2_0.copy()
    b2 = b2_0.copy()

    # 1st moment
    mW1 = 0
    mb1 = 0
    mW2 = 0
    mb2 = 0

    # 2nd moment
    vW1 = 0
    vb1 = 0
    vW2 = 0
    vb2 = 0
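    # (the moment estimates start as scalar zeros and broadcast to each
    # parameter's shape on the first update)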

    # hyperparams
    lr0 = 0.001
    beta1 = 0.9
    beta2 = 0.999
    eps = 1e-8
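    # lr0, beta1, beta2, eps above are the default settings suggested in the
    # Adam paper (Kingma & Ba, 2015)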

    # 1. Adam
    loss_adam = []
    err_adam = []
    t = 1
    for i in range(max_iter):
        for j in range(n_batches):
            Xbatch = Xtrain[j*batch_sz:(j*batch_sz + batch_sz),]
            Ybatch = Ytrain_ind[j*batch_sz:(j*batch_sz + batch_sz),]
            pYbatch, Z = forward(Xbatch, W1, b1, W2, b2)

            # updates
            # gradients
            gW2 = derivative_w2(Z, Ybatch, pYbatch) + reg*W2
            gb2 = derivative_b2(Ybatch, pYbatch) + reg*b2
            gW1 = derivative_w1(Xbatch, Z, Ybatch, pYbatch, W2) + reg*W1
            gb1 = derivative_b1(Z, Ybatch, pYbatch, W2) + reg*b1

            # new m
            mW1 = beta1 * mW1 + (1 - beta1) * gW1
            mb1 = beta1 * mb1 + (1 - beta1) * gb1
            mW2 = beta1 * mW2 + (1 - beta1) * gW2
            mb2 = beta1 * mb2 + (1 - beta1) * gb2

            # new v
            vW1 = beta2 * vW1 + (1 - beta2) * gW1 * gW1
            vb1 = beta2 * vb1 + (1 - beta2) * gb1 * gb1
            vW2 = beta2 * vW2 + (1 - beta2) * gW2 * gW2
            vb2 = beta2 * vb2 + (1 - beta2) * gb2 * gb2

            # bias correction
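            # (m and v are initialized at 0, so early in training they are
            # biased toward 0; dividing by 1 - beta**t corrects this, and the
            # correction factor approaches 1 as t grows)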
            correction1 = 1 - beta1 ** t
            hat_mW1 = mW1 / correction1
            hat_mb1 = mb1 / correction1
            hat_mW2 = mW2 / correction1
            hat_mb2 = mb2 / correction1

            correction2 = 1 - beta2 ** t
            hat_vW1 = vW1 / correction2
            hat_vb1 = vb1 / correction2
            hat_vW2 = vW2 / correction2
            hat_vb2 = vb2 / correction2

            # update t
            t += 1

            # apply updates to the params
            W1 = W1 - lr0 * hat_mW1 / np.sqrt(hat_vW1 + eps)
            b1 = b1 - lr0 * hat_mb1 / np.sqrt(hat_vb1 + eps)
            W2 = W2 - lr0 * hat_mW2 / np.sqrt(hat_vW2 + eps)
            b2 = b2 - lr0 * hat_mb2 / np.sqrt(hat_vb2 + eps)
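            # (note: eps is added inside the sqrt here; the Adam paper adds it
            # outside, i.e. np.sqrt(hat_v) + eps; either form works in practice)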


            if j % print_period == 0:
                pY, _ = forward(Xtest, W1, b1, W2, b2)
                l = cost(pY, Ytest_ind)
                loss_adam.append(l)
                print("Cost at iteration i=%d, j=%d: %.6f" % (i, j, l))

                err = error_rate(pY, Ytest)
                err_adam.append(err)
                print("Error rate:", err)

    pY, _ = forward(Xtest, W1, b1, W2, b2)
    print("Final error rate:", error_rate(pY, Ytest))


    # 2. RMSprop with momentum
    W1 = W1_0.copy()
    b1 = b1_0.copy()
    W2 = W2_0.copy()
    b2 = b2_0.copy()
    loss_rms = []
    err_rms = []

    lr0 = 0.001  # if you set this too high you'll get NaN!
    mu = 0.9
    decay_rate = 0.999
    eps = 1e-8

    # rmsprop cache
    cache_W2 = 1
    cache_b2 = 1
    cache_W1 = 1
    cache_b1 = 1
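    # (the caches start at 1 rather than 0 so the first few updates are not
    # divided by a near-zero denominator)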

    # momentum
    dW1 = 0
    db1 = 0
    dW2 = 0
    db2 = 0

    for i in range(max_iter):
        for j in range(n_batches):
            Xbatch = Xtrain[j*batch_sz:(j*batch_sz + batch_sz),]
            Ybatch = Ytrain_ind[j*batch_sz:(j*batch_sz + batch_sz),]
            pYbatch, Z = forward(Xbatch, W1, b1, W2, b2)

            # updates
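            # (for each parameter: the cache is an exponentially decaying
            # average of squared gradients, and the velocity d combines the
            # momentum term mu*d with an RMSprop-scaled gradient step)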
            gW2 = derivative_w2(Z, Ybatch, pYbatch) + reg*W2
            cache_W2 = decay_rate*cache_W2 + (1 - decay_rate)*gW2*gW2
            dW2 = mu * dW2 - (1 - mu) * lr0 * gW2 / (np.sqrt(cache_W2) + eps)
            W2 += dW2

            gb2 = derivative_b2(Ybatch, pYbatch) + reg*b2
            cache_b2 = decay_rate*cache_b2 + (1 - decay_rate)*gb2*gb2
            db2 = mu * db2 - (1 - mu) * lr0 * gb2 / (np.sqrt(cache_b2) + eps)
            b2 += db2

            gW1 = derivative_w1(Xbatch, Z, Ybatch, pYbatch, W2) + reg*W1
            cache_W1 = decay_rate*cache_W1 + (1 - decay_rate)*gW1*gW1
            dW1 = mu * dW1 - (1 - mu) * lr0 * gW1 / (np.sqrt(cache_W1) + eps)
            W1 += dW1

            gb1 = derivative_b1(Z, Ybatch, pYbatch, W2) + reg*b1
            cache_b1 = decay_rate*cache_b1 + (1 - decay_rate)*gb1*gb1
            db1 = mu * db1 - (1 - mu) * lr0 * gb1 / (np.sqrt(cache_b1) + eps)
            b1 += db1

            if j % print_period == 0:
                pY, _ = forward(Xtest, W1, b1, W2, b2)
                l = cost(pY, Ytest_ind)
                loss_rms.append(l)
                print("Cost at iteration i=%d, j=%d: %.6f" % (i, j, l))

                err = error_rate(pY, Ytest)
                err_rms.append(err)
                print("Error rate:", err)

    pY, _ = forward(Xtest, W1, b1, W2, b2)
    print("Final error rate:", error_rate(pY, Ytest))

    plt.plot(loss_adam, label='adam')
    plt.plot(loss_rms, label='rmsprop')
    plt.legend()
    plt.show()
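    # err_adam and err_rms were collected above as well; to compare error-rate
    # curves instead of cost curves, one could also plot them, e.g.:
    # plt.plot(err_adam, label='adam')
    # plt.plot(err_rms, label='rmsprop')
    # plt.legend()
    # plt.show()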


if __name__ == '__main__':
    main()
