Skip to content

Commit

Permalink
Fix crash during printing of the profiler of optimizer.
Browse files Browse the repository at this point in the history
  • Loading branch information
nouiz committed Oct 7, 2013
1 parent 8c30d10 commit 9e8a326
Showing 1 changed file with 17 additions and 18 deletions.
35 changes: 17 additions & 18 deletions theano/gof/opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -1424,15 +1424,12 @@ def apply(self, fgraph, start_from=None):

loop_timing = []
global_opt_timing = []
time_lopts = {}
time_opts = {}
io_toposort_timing = []
nb_nodes = []
for gopt in self.global_optimizers:
process_count.setdefault(gopt, 0)

for lopt in self.local_optimizers:
process_count.setdefault(lopt, 0)
time_lopts.setdefault(lopt, 0)
for opt in self.global_optimizers + self.local_optimizers:
process_count.setdefault(opt, 0)
time_opts.setdefault(opt, 0)

while changed and not max_use_abort:
t0 = time.time()
Expand All @@ -1441,7 +1438,9 @@ def apply(self, fgraph, start_from=None):
#apply global optimizers
for gopt in self.global_optimizers:
fgraph.change_tracker.reset()
t_opt = time.time()
gopt.apply(fgraph)
time_opts[gopt] += time.time() - t_opt
if fgraph.change_tracker.changed:
process_count[gopt] += 1
changed = True
Expand Down Expand Up @@ -1482,9 +1481,9 @@ def pruner(node):
current_node = node

for lopt in self.local_optimizers:
t_lopt = time.time()
t_opt = time.time()
lopt_change = self.process_node(fgraph, node, lopt)
time_lopts[lopt] += time.time() - t_lopt
time_opts[lopt] += time.time() - t_opt
if lopt_change:
process_count[lopt] += 1
changed = True
Expand All @@ -1507,7 +1506,7 @@ def pruner(node):
config.optdb.max_use_ratio)

return (self, loop_timing, process_count, max_nb_nodes,
global_opt_timing, nb_nodes, time_lopts, io_toposort_timing)
global_opt_timing, nb_nodes, time_opts, io_toposort_timing)

def print_summary(self, stream=sys.stdout, level=0, depth=-1):
name = getattr(self, 'name', None)
Expand All @@ -1521,7 +1520,7 @@ def print_summary(self, stream=sys.stdout, level=0, depth=-1):
@staticmethod
def print_profile(stream, prof, level=0):
(opt, loop_timing, process_count, max_nb_nodes,
global_opt_timing, nb_nodes, time_lopts, io_toposort_timing) = prof
global_opt_timing, nb_nodes, time_opts, io_toposort_timing) = prof
blanc = (' ' * level)
print >> stream, blanc, "EquilibriumOptimizer",
print >> stream, blanc, getattr(opt, "name",
Expand All @@ -1540,7 +1539,7 @@ def print_profile(stream, prof, level=0):
count_opt = []
for opt, count in process_count.iteritems():
if count > 0:
count_opt.append((time_lopts[opt], count, opt))
count_opt.append((time_opts[opt], count, opt))

if count_opt:
print >> stream, blanc, \
Expand All @@ -1554,7 +1553,7 @@ def print_profile(stream, prof, level=0):
@staticmethod
def merge_profile(prof1, prof2):
#(opt, loop_timing, process_count, max_nb_nodes,
# global_opt_timing, nb_nodes, time_lopts, io_toposort_timing) = prof1
# global_opt_timing, nb_nodes, time_opts, io_toposort_timing) = prof1

local_optimizers = set(prof1[0].local_optimizers).union(
prof2[0].local_optimizers)
Expand Down Expand Up @@ -1588,12 +1587,12 @@ def merge_list(l1, l2):

nb_nodes = merge_list(prof1[5], prof2[5])

time_lopts = prof1[6].copy()
time_opts = prof1[6].copy()
for opt, t in prof2[6].iteritems():
if opt in time_lopts:
time_lopts[opt] += t
if opt in time_opts:
time_opts[opt] += t
else:
time_lopts[opt] = t
time_opts[opt] = t

io_toposort_timing = merge_list(prof1[7], prof2[7])

Expand All @@ -1606,7 +1605,7 @@ def merge_list(l1, l2):
max_nb_nodes,
global_opt_timing,
nb_nodes,
time_lopts,
time_opts,
io_toposort_timing)

#################
Expand Down

0 comments on commit 9e8a326

Please sign in to comment.