Skip to content

Commit

Permalink
add test for an optimization about scan sequences
Browse files Browse the repository at this point in the history
  • Loading branch information
nouiz committed Oct 7, 2013
1 parent c1ebebc commit 4c30389
Showing 1 changed file with 52 additions and 9 deletions.
61 changes: 52 additions & 9 deletions theano/scan_module/tests/test_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -3564,11 +3564,10 @@ def test_scan_merge_nodes(self):
assert not opt_obj.belongs_to_set(scan_node1, [scan_node2])
assert not opt_obj.belongs_to_set(scan_node2, [scan_node1])

def test_remove_constants_and_unused_inputs_scan(self):
"""
Test the opt remove_constants_and_unused_inputs_scan
def test_remove_constants_and_unused_inputs_scan_non_seqs(self):
"""Test the opt remove_constants_and_unused_inputs_scan for
non sequences.
TODO: currently we only test non_seqs, should test
"""
W = theano.tensor.matrix(name='W')
v = theano.tensor.ivector(name='v')
Expand All @@ -3594,17 +3593,61 @@ def test_remove_constants_and_unused_inputs_scan(self):
f(numpy.zeros((3, 3), dtype=theano.config.floatX), [1, 2])
scan_node = f.maker.fgraph.toposort()[-1]

# TODO: Why this assert always fail?
# assert (len(scan_node.inputs) ==
# len(set(scan_node.inputs)))
# The first input is the number of iteration.
assert (len(scan_node.inputs[1:]) ==
len(set(scan_node.inputs[1:])))
inp = scan_node.op.inner_non_seqs(scan_node.op.inputs)
assert len(inp) == 1
assert (len(inp) == len(set(inp)))
inp = scan_node.op.outer_non_seqs(scan_node)
assert len(inp) == 1
assert (len(inp) == len(set(inp)))
#import pdb;pdb.set_trace()
#utt.assert_allclose(f([1, 2]), [[0, 0, 0], [1, 1, 1], [1, 1, 1]])

def test_remove_constants_and_unused_inputs_scan_seqs(self):
"""
Test the opt remove_constants_and_unused_inputs_scan for sequences.
"""
W = theano.tensor.matrix(name='W')
v = theano.tensor.ivector(name='v')
vv = theano.tensor.matrix(name='vv')
y1, _ = theano.scan(lambda i, W: W[i], sequences=v,
outputs_info=None, non_sequences=[W])
y2, _ = theano.scan(lambda i, _, W: W[i], sequences=[v, v],
outputs_info=None, non_sequences=W)
y3, _ = theano.scan(lambda i, _, W: W[i], sequences=[v, vv[0]],
outputs_info=None, non_sequences=W)
y4, _ = theano.scan(lambda _, i, W: W[i], sequences=[vv[0], v],
outputs_info=None, non_sequences=W)
y5, _ = theano.scan(lambda _, i, _2, W: W[i], sequences=[vv, v, vv[0]],
outputs_info=None, non_sequences=W)
y6, _ = theano.scan(lambda _, _2, i, W: W[i], sequences=[vv[0], vv, v],
outputs_info=None, non_sequences=W)
y7, _ = theano.scan(lambda i, _, _2, W: W[i],
sequences=[v, vv[0], vv[0]],
outputs_info=None, non_sequences=W)
y8, _ = theano.scan(lambda _, i, W, _2, _3: W[i], sequences=[vv[0], v],
outputs_info=None, non_sequences=[W, W[0], W[0]])
for out in [y1, y2, y3, y4, y5, y6, y7, y8]:
#This used to raise an exception
f = theano.function([W, v, vv], out, on_unused_input='ignore',
mode=mode_with_opt)
f(numpy.zeros((3, 3), theano.config.floatX),
[1, 2],
numpy.zeros((3, 3), theano.config.floatX))
scan_node = f.maker.fgraph.toposort()[-1]

# The first input is the number of iteration.
assert (len(scan_node.inputs[1:]) ==
len(set(scan_node.inputs[1:])))
inp = scan_node.op.inner_seqs(scan_node.op.inputs)
assert len(inp) == 1
inp = scan_node.op.outer_seqs(scan_node)
assert len(inp) == 1
inp = scan_node.op.inner_non_seqs(scan_node.op.inputs)
assert len(inp) == 1
inp = scan_node.op.outer_non_seqs(scan_node)
assert len(inp) == 1


def test_speed():
Expand Down

0 comments on commit 4c30389

Please sign in to comment.