Skip to content

Commit

Permalink
Add the opt tag name inplace_elemwise_optimizer to inplace_opt as this
Browse files Browse the repository at this point in the history
is the variable name used.

I lost 1h before undersanding it isn't the name used in the optdb.

Also test that we can disable inplace and fusion opt with MonitorMode.
  • Loading branch information
nouiz committed Jun 10, 2013
1 parent 554b4b5 commit ff0ac3c
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 3 deletions.
39 changes: 37 additions & 2 deletions theano/compile/tests/test_monitormode.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,9 +27,9 @@ def detect_nan(i, node, fn):
assert nan_detected[0]


def test_optimizers():
def test_optimizer():
"""
Test that we can remove optimizers
Test that we can remove optimizer
"""
nan_detected = [False]

Expand All @@ -54,3 +54,38 @@ def detect_nan(i, node, fn):

# Test that we still detect the nan
assert nan_detected[0]


def test_not_inplace():
"""
Test that we can remove optimizers including inplace optimizers
"""
nan_detected = [False]

def detect_nan(i, node, fn):
for output in fn.outputs:
if numpy.isnan(output[0]).any():
print '*** NaN detected ***'
theano.printing.debugprint(node)
print 'Inputs : %s' % [input[0] for input in fn.inputs]
print 'Outputs: %s' % [output[0] for output in fn.outputs]
nan_detected[0] = True
break

x = theano.tensor.vector('x')
mode = theano.compile.MonitorMode(post_func=detect_nan)
#mode = mode.excluding('fusion', 'inplace')
mode = mode.excluding('local_elemwise_fusion',
'inplace_elemwise_optimizer')
o = theano.tensor.outer(x, x)
out = theano.tensor.log(o) * o
f = theano.function([x], [out],
mode=mode)

# Test that the fusion wasn't done
assert len(f.maker.fgraph.nodes) == 5
assert not f.maker.fgraph.toposort()[-1].op.destroy_map
f([0, 0]) # log(0) * 0 = -inf * 0 = NaN

# Test that we still detect the nan
assert nan_detected[0]
2 changes: 1 addition & 1 deletion theano/tensor/opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,8 +273,8 @@ def inplace_elemwise_optimizer(fgraph):
return inplace_elemwise_optimizer

inplace_elemwise_optimizer = inplace_elemwise_optimizer_op(T.Elemwise)

compile.optdb.register('inplace_opt', inplace_elemwise_optimizer, 75,
'inplace_elemwise_optimizer',
'fast_run', 'inplace')


Expand Down

0 comments on commit ff0ac3c

Please sign in to comment.