Skip to content

Commit

Permalink
Redo plots for multiple values of eta, and add a second image of the …
Browse files Browse the repository at this point in the history
…gradient valley (this time without axis labels)
  • Loading branch information
mnielsen committed May 19, 2014
1 parent b5ee3ad commit 074671c
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 12 deletions.
2 changes: 1 addition & 1 deletion fig/multiple_eta.json
Original file line number Diff line number Diff line change
@@ -1 +1 @@
[[[], [], [43172.531738718622, 32607.45599125815, 28555.613320940582, 26223.019459805084, 24159.983453228011, 22842.160827228243, 21621.047507655829, 20898.172151309464, 19894.548289400343, 19178.117936905142, 18678.344114340609, 18125.735324308338, 17611.45436101771, 17287.666662665855, 16727.051368599092, 16415.475111081811, 16041.956738514162, 15709.779481248608, 15484.208231695791, 15294.436510794447, 14974.509277302383, 14801.619293212863, 14445.002057375092, 14279.559642422379, 14136.181827445431, 13903.073818770967, 13816.480322363819, 13595.195548292526, 13434.816999862142, 13153.711167719257], []], [[], [], [22223.657387371, 16896.443188536523, 16227.946535791749, 14174.094557968658, 13979.743988155209, 13620.22261993784, 12558.609809802616, 12536.331953998095, 11798.714492378163, 11699.198219592367, 11279.023385833596, 11638.625883655473, 11069.999725131602, 11638.970544434467, 11675.992533681885, 11100.949640822671, 10560.967141927669, 10693.213593158665, 10642.297918774468, 10569.552676543088, 10400.459096766863, 10305.715065674478, 9986.9725955374724, 10693.550467669651, 10916.098044629935, 10233.006856353355, 10283.159058328103, 10174.834475910666, 10457.710465303051, 10855.211856702586], []], [[], [], [31636.728121584703, 27640.439299218131, 30072.97015115829, 25572.283298173268, 24707.678800971247, 42144.315119767081, 25089.315759748908, 23797.484884104135, 22578.576025909235, 25585.167160429606, 22450.246455185665, 22490.784478277921, 25563.47618231974, 27937.678507188739, 25496.289434475395, 24897.89522640916, 25888.705592370072, 26610.525992113908, 26565.263730368992, 23579.526230927826, 32471.740188771659, 29948.455174825744, 28777.791150455723, 22267.045098921677, 26423.254503818986, 27656.769372739484, 26768.749142124725, 29470.289669647344, 29413.102443044823, 28516.878755790236], []]]
[[[], [], [0.87809508908377998, 0.67406552530098141, 0.59798920430275404, 0.55533015743656189, 0.51751101003208144, 0.4942033354556824, 0.47255041042913526, 0.46069879353359433, 0.44304475294352064, 0.43099562372228112, 0.42310993427766375, 0.41408265298981006, 0.40573464183982105, 0.40110722961828227, 0.39162028064538967, 0.38705015774740958, 0.38116357043417587, 0.37603986695304614, 0.37297012040237154, 0.37057334627661631, 0.36551756338853658, 0.36335674264586654, 0.35745296185579917, 0.35535960956849127, 0.35365591135061097, 0.35011353300568238, 0.34946519495897871, 0.34604661988238178, 0.34386077098862522, 0.33919980880230349], []], [[], [], [0.49501954654296704, 0.4063145129425576, 0.40482383242804637, 0.37156577828840276, 0.37380111172151681, 0.37152751786000143, 0.35371985224004426, 0.3557161388797867, 0.34323780090168027, 0.3433514311156789, 0.3367645441708797, 0.34532085892085329, 0.33506383267050244, 0.34760988079085842, 0.34921493732996928, 0.33853424834583179, 0.32837282561262077, 0.33175599401109612, 0.33132920379429243, 0.33024353325326034, 0.32736756892399654, 0.3259638557593546, 0.32004264784244907, 0.33424319076405928, 0.33878125802305081, 0.32521839878261177, 0.32679267619514646, 0.32488571435373748, 0.33056367198473002, 0.33879633130932685], []], [[], [], [0.92489293305102116, 0.83919130289246469, 0.88748421594232696, 0.79625231780396133, 0.78117959228699174, 1.1365919079387048, 0.78787239608336346, 0.76778614131217449, 0.73689525303227721, 0.80127437393519696, 0.74433665287336681, 0.73725544607013882, 0.80249602203179993, 0.85190338199210014, 0.79872168623645712, 0.80243104440756152, 0.80649160680410659, 0.81467254023600921, 0.82526467696100858, 0.75042379852601759, 0.93658673378777402, 0.88236662906752283, 0.86121396033520892, 0.72492681699401829, 0.80405009868466648, 0.83959963179208197, 0.83387510808276821, 0.88282498566307899, 0.88583473645177979, 0.86068501713490919], []]]
Binary file modified fig/multiple_eta.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
17 changes: 8 additions & 9 deletions fig/multiple_eta.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

# Standard library
import json
import random
import sys

# My library
Expand All @@ -20,13 +21,8 @@
import matplotlib.pyplot as plt
import numpy as np

# Make results more easily reproducible
np.random.seed(12345678)
import random
random.seed(12345678)

# Constants
LEARNING_RATES = [0.0025, 0.025, 0.25]
LEARNING_RATES = [0.025, 0.25, 2.5]
COLORS = ['#2A6EA6', '#FFCD33', '#FF7033']
NUM_EPOCHS = 30

Expand All @@ -40,15 +36,18 @@ def run_networks():
they can later be used by ``make_plot``.
"""
# Make results more easily reproducible
random.seed(12345678)
np.random.seed(12345678)
training_data, validation_data, test_data = mnist_loader.load_data_wrapper()
results = []
for eta in LEARNING_RATES:
print "\nTrain a network using eta = "+str(eta)
net = network2.Network([784, 30, 10])
results.append(
net.SGD(training_data, NUM_EPOCHS, 10, eta,
evaluation_data=validation_data, lmbda = 0.001,
monitor_training_cost=True))
net.SGD(training_data, NUM_EPOCHS, 10, eta, lmbda=5.0,
evaluation_data=validation_data,
monitor_training_cost=True))
f = open("multiple_eta.json", "w")
json.dump(results, f)
f.close()
Expand Down
3 changes: 1 addition & 2 deletions fig/valley.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,9 +36,8 @@
ax.w_xaxis.set_major_locator(LinearLocator(3))
ax.w_yaxis.set_major_locator(LinearLocator(3))
ax.w_zaxis.set_major_locator(LinearLocator(3))
ax.text(1.79, 0, 1.62, "$C$", fontsize=20)
ax.text(0.05, -1.8, 0, "$v_1$", fontsize=20)
ax.text(1.5, -0.25, 0, "$v_2$", fontsize=20)
ax.text(1.79, 0, 1.62, "$C$", fontsize=20)

plt.show()

Binary file added fig/valley2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
48 changes: 48 additions & 0 deletions fig/valley2.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
"""valley2.py
~~~~~~~~~~~~~
Plots a function of two variables to minimize. The function is a
fairly generic valley function.
Note that this is a duplicate of valley.py, but omits labels on the
axis. It's bad practice to duplicate in this way, but I had
considerable trouble getting matplotlib to update a graph in the way I
needed (adding or removing labels), so finally fell back on this as a
kludge solution.
"""

#### Libraries
# Third party libraries
from matplotlib.ticker import LinearLocator
# Note that axes3d is not explicitly used in the code, but is needed
# to register the 3d plot type correctly
from mpl_toolkits.mplot3d import axes3d
import matplotlib.pyplot as plt
import numpy

fig = plt.figure()
ax = fig.gca(projection='3d')
X = numpy.arange(-1, 1, 0.1)
Y = numpy.arange(-1, 1, 0.1)
X, Y = numpy.meshgrid(X, Y)
Z = X**2 + Y**2

colortuple = ('w', 'b')
colors = numpy.empty(X.shape, dtype=str)
for x in xrange(len(X)):
for y in xrange(len(Y)):
colors[x, y] = colortuple[(x + y) % 2]

surf = ax.plot_surface(X, Y, Z, rstride=1, cstride=1, facecolors=colors,
linewidth=0)

ax.set_xlim3d(-1, 1)
ax.set_ylim3d(-1, 1)
ax.set_zlim3d(0, 2)
ax.w_xaxis.set_major_locator(LinearLocator(3))
ax.w_yaxis.set_major_locator(LinearLocator(3))
ax.w_zaxis.set_major_locator(LinearLocator(3))
ax.text(1.79, 0, 1.62, "$C$", fontsize=20)

plt.show()

0 comments on commit 074671c

Please sign in to comment.