-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathneural_network_mnist.jl
43 lines (28 loc) · 977 Bytes
/
neural_network_mnist.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
# this is from https://github.com/FluxML/model-zoo/blob/master/vision/mnist/mlp.jl
# This is pretty slow, so run it at your own risk
using Flux, Flux.Data.MNIST, Statistics
using Flux: onehotbatch, onecold, crossentropy, throttle
using Base.Iterators: repeated
# using CuArrays
# Classify MNIST digits with a simple multi-layer-perceptron
# Load the MNIST training images (a Vector of 28x28 grayscale images).
imgs = MNIST.images()
# Stack images into one large batch: flatten each image to a 784-vector,
# convert to Float, and concatenate columns into a 784 x 60000 matrix.
# NOTE: reduce(hcat, vs) uses Julia's specialized fast method; splatting
# 60k arrays into hcat(vs...) is a varargs anti-pattern and very slow.
X = reduce(hcat, float.(reshape.(imgs, :)))
labels = MNIST.labels()
# One-hot-encode the labels: 10 x 60000 matrix, rows indexed by digits 0:9.
Y = onehotbatch(labels, 0:9)
# Simple multi-layer perceptron: 784 inputs (flattened 28x28 image),
# one hidden layer of 32 relu units, 10 output logits, softmax to
# turn the logits into class probabilities.
m = Chain(
    Dense(28^2, 32, relu),  # input -> hidden
    Dense(32, 10),          # hidden -> per-digit logits
    softmax,                # logits -> probabilities
)
# Cross-entropy between the model's predicted class probabilities
# and the one-hot target matrix `y`.
function loss(x, y)
    return crossentropy(m(x), y)
end
# Fraction of examples whose predicted class (argmax of the model
# output) matches the label encoded in the one-hot matrix `y`.
function accuracy(x, y)
    predicted = onecold(m(x))
    actual = onecold(y)
    return mean(predicted .== actual)
end
# "Dataset" = the same full batch repeated 200 times (200 epochs of
# full-batch gradient descent).
dataset = repeated((X, Y), 200)
# Progress callback: print the current full-batch loss.
evalcb() = @show(loss(X, Y))
opt = ADAM()
# Train, printing the loss at most once every 10 seconds.
Flux.train!(loss, Flux.params(m), dataset, opt, cb = throttle(evalcb, 10))
# Training-set accuracy after the run.
accuracy(X, Y)
# Test set accuracy.
# Build the 784 x 10000 test matrix the same way as the training batch.
# NOTE: reduce(hcat, vs) hits the specialized fast method; splatting the
# image vector into hcat(vs...) is a varargs anti-pattern and very slow.
tX = reduce(hcat, float.(reshape.(MNIST.images(:test), :)))
tY = onehotbatch(MNIST.labels(:test), 0:9)
accuracy(tX, tY)