Skip to content

Commit

Permalink
Merge pull request Theano#1549 from lamblin/fix_shape_cycle
Browse files Browse the repository at this point in the history
[WIP] Fix shape cycle
  • Loading branch information
nouiz committed Oct 7, 2013
2 parents 361ca86 + 7dae6a7 commit 15ad6f3
Showing 1 changed file with 20 additions and 7 deletions.
27 changes: 20 additions & 7 deletions theano/tensor/opt.py
Original file line number Diff line number Diff line change
Expand Up @@ -898,22 +898,29 @@ def update_shape(self, r, other_r):
r_shape = self.shape_of[r]
else:
# If no info is known on r's shape, use other_shape
self.shape_of[r] = other_shape
for sv in other_shape:
self.shape_of_reverse_index.setdefault(sv, set()).add(r)
self.set_shape(r, other_shape)
return

# Merge other_shape with r_shape, giving the priority to other_shape
merged_shape = []
for i, ps in enumerate(other_shape):
# If other_shape[i] is uninformative, use r_shape[i].
# For now, we consider 2 cases of uninformative other_shape[i]:
# - Shape_i(i)(other_r);
# - Shape_i(i)(r).
if (ps.owner
and isinstance(getattr(ps.owner, 'op', None), Shape_i)
and ps.owner.op.i == i
and ps.owner.inputs[0] in (r, other_r)):
# If other_shape[i] is uninformative, use r_shape[i].
# For now, we consider 2 cases of uninformative other_shape[i]:
# - Shape_i(i)(other_r);
# - Shape_i(i)(r).
merged_shape.append(r_shape[i])
elif r_shape[i] in theano.gof.graph.ancestors([other_shape[i]]):
# Another case where we want to use r_shape[i] is when
# other_shape[i] actually depends on r_shape[i]. In that case,
# we do not want to substitute an expression with another that
# is strictly more complex. Such a substitution could also lead
# to cycles: if (in the future) r_shape[i] gets replaced by an
# expression of other_shape[i], other_shape[i] may end up
# depending on itself.
merged_shape.append(r_shape[i])
else:
merged_shape.append(other_shape[i])
Expand Down Expand Up @@ -1107,6 +1114,12 @@ def on_change_input(self, fgraph, node, i, r, new_r):
# replacement.
continue

if shpnode.outputs[0] in theano.gof.graph.ancestors([repl]):
raise AssertionError(
"This substitution would insert a cycle in the graph:"
"node: %s, i: %i, r: %s, new_r: %s"
% (node, i, r, new_r))

self.scheduled[shpnode] = new_r
# In case 2, if r is a variable that we've scheduled for shape update,
# then we should cancel it.
Expand Down

0 comments on commit 15ad6f3

Please sign in to comment.