Skip to content

Commit

Permalink
Remove AttributeError trap (lanpa#241)
Browse files Browse the repository at this point in the history
* Remove AttributeError trap

* Trap AttributeError in run_pass

* One more fix
  • Loading branch information
orionr authored and lanpa committed Oct 9, 2018
1 parent ddca1a9 commit 4093ead
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 9 deletions.
9 changes: 6 additions & 3 deletions tensorboardX/pytorch_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,9 +96,12 @@ def run_pass(name, trace):
graph = trace.graph()

torch._C._jit_pass_lint(graph)
result = getattr(torch._C, '_jit_pass_' + name)(graph)
if result is not None:
graph = result
try:
result = getattr(torch._C, '_jit_pass_' + name)(graph)
if result is not None:
graph = result
except AttributeError:
pass
torch._C._jit_pass_lint(graph)

if set_graph:
Expand Down
11 changes: 5 additions & 6 deletions tensorboardX/writer.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,20 +510,19 @@ def add_text(self, tag, text_string, global_step=None, walltime=None):
def add_onnx_graph(self, prototxt):
self.file_writer.add_onnx_graph(gg(prototxt))

# Supports both Caffe2 and PyTorch models
def add_graph(self, model, input_to_model=None, verbose=False, **kwargs):
# prohibit second call?
# no, let tensorboard handles it and show its warning message.
# no, let tensorboard handle it and show its warning message.
"""Add graph data to summary.
Args:
model (torch.nn.Module): model to draw.
input_to_model (torch.autograd.Variable): a variable or a tuple of variables to be fed.
input_to_model (torch.autograd.Variable): a variable or a tuple of
variables to be fed.
"""
try:
if hasattr(model, 'forward'):
# A valid PyTorch model should have a 'forward' method
_ = getattr(model, 'forward')
import torch
from distutils.version import LooseVersion
if LooseVersion(torch.__version__) >= LooseVersion("0.3.1"):
Expand All @@ -536,7 +535,7 @@ def add_graph(self, model, input_to_model=None, verbose=False, **kwargs):
print('add_graph() only supports PyTorch v0.2.')
return
self.file_writer.add_graph(graph(model, input_to_model, verbose))
except AttributeError:
else:
# Caffe2 models do not have the 'forward' method
if not self.caffe2_enabled:
# TODO (ml7): Remove when PyTorch 1.0 merges PyTorch and Caffe2
Expand Down

0 comments on commit 4093ead

Please sign in to comment.