Skip to content

Commit

Permalink
Small fix to pydotprint_variables
Browse files Browse the repository at this point in the history
  • Loading branch information
nouiz committed Oct 23, 2014
1 parent bbf5269 commit c331362
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 1 deletion.
3 changes: 2 additions & 1 deletion theano/printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -992,7 +992,7 @@ def plot_apply(app, d):
if nd.owner:
plot_apply(nd.owner, depth)
try:
g.write_png(outfile, prog='dot')
g.write(outfile, prog='dot', format=format)
except pd.InvocationException, e:
# Some version of pydot are bugged/don't work correctly with
# empty label. Provide a better user error message.
Expand All @@ -1006,6 +1006,7 @@ def plot_apply(app, d):
" Theano. Using another version of pydot could"
" fix this problem. The pydot error is: " +
e.message)
raise

print 'The output file is available at', outfile

Expand Down
32 changes: 32 additions & 0 deletions theano/tests/test_printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from theano.printing import min_informative_str, debugprint
from theano.compat.six import StringIO


def test_pydotprint_cond_highlight():
"""
This is a REALLY PARTIAL TEST.
Expand Down Expand Up @@ -44,6 +45,37 @@ def test_pydotprint_cond_highlight():
' is no IfElse node in the graph\n')


def test_pydotprint_variables():
"""
This is a REALLY PARTIAL TEST.
I did them to help debug stuff.
It make sure the code run.
"""

# Skip test if pydot is not available.
if not theano.printing.pydot_imported:
raise SkipTest('pydot not available')

x = tensor.dvector()

s = StringIO()
new_handler = logging.StreamHandler(s)
new_handler.setLevel(logging.DEBUG)
orig_handler = theano.logging_default_handler

theano.theano_logger.removeHandler(orig_handler)
theano.theano_logger.addHandler(new_handler)
theano.theano_logger.removeHandler(orig_handler)
theano.theano_logger.addHandler(new_handler)
try:
theano.printing.pydotprint_variables(x * 2)
finally:
theano.theano_logger.addHandler(orig_handler)
theano.theano_logger.removeHandler(new_handler)


def test_pydotprint_long_name():
"""This is a REALLY PARTIAL TEST.
Expand Down

0 comments on commit c331362

Please sign in to comment.