Skip to content

Commit

Permalink
Various optimizations
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewearl committed Apr 12, 2016
1 parent 3294567 commit beb184a
Showing 1 changed file with 67 additions and 8 deletions.
75 changes: 67 additions & 8 deletions train.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import functools
import glob
import itertools
import multiprocessing
import random
import Queue
import sys
import threading

import cv2
import numpy
Expand Down Expand Up @@ -151,10 +155,67 @@ def batch(it, batch_size):
yield out


def mpgen(f):
def main(q, args, kwargs):
try:
for item in f(*args, **kwargs):
q.put(item)
finally:
print "Closing queue"
q.close()

@functools.wraps(f)
def wrapped(*args, **kwargs):
q = multiprocessing.Queue(3)
proc = multiprocessing.Process(target=main,
args=(q, args, kwargs))
proc.start()
try:
while True:
item = q.get()
yield item
finally:
print "Waiting for proc to exit"
proc.terminate()
proc.join()

return wrapped


def threadgen(f):
@functools.wraps(f)
def wrapped(*args, **kwargs):
def main():
for item in f(*args, **kwargs):
if stop:
break
q.put(item)

q = Queue.Queue(3)
stop = False
thr = threading.Thread(target=main)
thr.start()
try:
while True:
item = q.get()
yield item
finally:
print "Waiting for proc to exit"
stop = True
try:
q.get_nowait()
except Queue.Empty:
pass
thr.join()

return wrapped

@mpgen
def read_batches(batch_size):
def gen_vecs():
for im, c, p in gen.generate_ims(batch_size, bg_prob=0.0):
yield im_to_vec(im), code_to_vec(p, c)

while True:
yield unzip(gen_vecs())

Expand Down Expand Up @@ -300,11 +361,12 @@ def do_report():

def do_batch():
sess.run(train_step,
feed_dict={x: batch_xs, y_: batch_ys})
feed_dict={x: batch_xs, y_: batch_ys})
if batch_idx % report_steps == 0:
do_report()

with tf.Session() as sess:
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=0.7)
with tf.Session(config=tf.ConfigProto(gpu_options=gpu_options)) as sess:
sess.run(init)
if initial_weights is not None:
sess.run(assign_ops)
Expand All @@ -314,12 +376,9 @@ def do_batch():
try:
batch_iter = enumerate(read_batches(batch_size))
for batch_idx, (batch_xs, batch_ys) in batch_iter:
weights = [p.eval() for p in params]
if all(numpy.all(numpy.isnan(w)) for w in weights):
raise WeightsWentNan
last_weights = weights
do_batch()
except KeyboardInterrupt, WeightsWentNan:
except KeyboardInterrupt:
last_weights = [p.eval() for p in params]
numpy.savez("weights.npz", *last_weights)


Expand All @@ -338,7 +397,7 @@ def do_batch():
bg_prob=0.5,
initial_weights=initial_weights)
elif sys.argv[1] == "read":
train_reader(learn_rate=0.0001,
train_reader(learn_rate=0.001,
report_steps=20,
batch_size=50,
initial_weights=initial_weights)
Expand Down

0 comments on commit beb184a

Please sign in to comment.