Skip to content

Commit

Permalink
Switch from 'setattr' to decorator. (fossasia#392)
Browse files Browse the repository at this point in the history
Summary:
First pass at resolving fossasia#389.
Closes fossasia#392

Reviewed By: lvdmaaten

Differential Revision: D8543559

Pulled By: JackUrb

fbshipit-source-id: 6070599d5c9ed85c9ccc2b3dc34d6ca6582f5115
  • Loading branch information
malmaud authored and facebook-github-bot committed Jun 25, 2018
1 parent 0eda08f commit 2305c7c
Show file tree
Hide file tree
Showing 2 changed files with 63 additions and 20 deletions.
68 changes: 50 additions & 18 deletions py/visdom/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import time
import errno
import io
from functools import wraps
try:
import bs4
BS4_AVAILABLE = True
Expand Down Expand Up @@ -249,23 +250,41 @@ def _assert_opts(opts):
assert isstr(opts.get('title')), 'title should be a string'


def pytorch_wrap(fn):
def result(*args, **kwargs):
args = (a.cpu().detach().numpy() if type(a).__module__.startswith('torch') else a for a in args)
torch_types = []
try:
import torch
torch_types.append(torch.Tensor)
torch_types.append(torch.nn.Parameter)
except (ImportError, AttributeError):
pass

for k in kwargs:
if type(kwargs[k]).__module__.startswith('torch'):
kwargs[k] = kwargs[k].cpu().detach().numpy()

return fn(*args, **kwargs)
return result
def _torch_to_numpy(a):
if len(torch_types) > 0:
if isinstance(a, torch.autograd.Variable):
# For PyTorch < 0.4 comptability.
warnings.warn(
"Support for versions of PyTorch less than 0.4 is deprecated and "
"will eventually be removed.", DeprecationWarning)
a = a.data
for kind in torch_types:
if isinstance(a, kind):
# For PyTorch < 0.4 comptability, where non-Variable
# tensors do not have a 'detach' method. Will be removed.
if hasattr(a, 'detach'):
a = a.detach()
return a.cpu().numpy()
return a


def wrap_tensor_methods(cls, wrapper):
fns = ['_surface', 'bar', 'boxplot', 'surf', 'heatmap', 'histogram', 'svg',
'image', 'images', 'line', 'pie', 'scatter', 'stem', 'quiver', 'contour']
for key in [k for k in dir(cls) if k in fns]:
setattr(cls, key, wrapper(getattr(cls, key)))
def pytorch_wrap(f):
@wraps(f)
def wrapped_f(*args, **kwargs):
args = (_torch_to_numpy(arg) for arg in args)
kwargs = {k: _torch_to_numpy(v) for (k, v) in kwargs.items()}
return f(*args, **kwargs)

return wrapped_f


class Visdom(object):
Expand Down Expand Up @@ -299,11 +318,6 @@ def __init__(
# Flag to indicate whether to raise errors or suppress them
self.raise_exceptions = raise_exceptions
self.log_to_filename = log_to_filename
try:
import torch # noqa F401: we do use torch, just weirdly
wrap_tensor_methods(self, pytorch_wrap)
except ImportError:
pass

self._send({
'eid': env,
Expand Down Expand Up @@ -605,6 +619,7 @@ def properties(self, data, win=None, env=None, opts=None):
'opts': opts,
}, endpoint='events')

@pytorch_wrap
def svg(self, svgstr=None, svgfile=None, win=None, env=None, opts=None):
"""
This function draws an SVG object. It takes as input an SVG string or the
Expand Down Expand Up @@ -667,6 +682,7 @@ def matplot(self, plot, opts=None, env=None, win=None):
opts['width'] = 1.35 * int(math.ceil(float(width.group(1))))
return self.svg(svgstr=svg, opts=opts, env=env, win=win)

@pytorch_wrap
def image(self, img, win=None, env=None, opts=None):
"""
This function draws an img. It takes as input an `CxHxW` or `HxW` tensor
Expand Down Expand Up @@ -710,6 +726,7 @@ def image(self, img, win=None, env=None, opts=None):
'opts': opts,
})

@pytorch_wrap
def images(self, tensor, nrow=8, padding=2,
win=None, env=None, opts=None):
"""
Expand Down Expand Up @@ -757,6 +774,7 @@ def images(self, tensor, nrow=8, padding=2,

return self.image(grid, win, env, opts)

@pytorch_wrap
def audio(self, tensor=None, audiofile=None, win=None, env=None, opts=None):
"""
This function plays audio. It takes as input the filename of the audio
Expand Down Expand Up @@ -800,6 +818,7 @@ def audio(self, tensor=None, audiofile=None, win=None, env=None, opts=None):
opts['width'] = 330
return self.text(text=videodata, win=win, env=env, opts=opts)

@pytorch_wrap
def video(self, tensor=None, videofile=None, win=None, env=None, opts=None):
"""
This function plays a video. It takes as input the filename of the video
Expand Down Expand Up @@ -876,6 +895,7 @@ def update_window_opts(self, win, opts, env=None):
}
return self._send(data_to_send, endpoint='update')

@pytorch_wrap
def scatter(self, X, Y=None, win=None, env=None, opts=None, update=None,
name=None):
"""
Expand Down Expand Up @@ -1032,6 +1052,7 @@ def scatter(self, X, Y=None, win=None, env=None, opts=None, update=None,

return self._send(data_to_send, endpoint=endpoint)

@pytorch_wrap
def line(self, Y, X=None, win=None, env=None, opts=None, update=None,
name=None):
"""
Expand Down Expand Up @@ -1102,6 +1123,7 @@ def line(self, Y, X=None, win=None, env=None, opts=None, update=None,
return self.scatter(X=linedata, Y=labels, opts=opts, win=win, env=env,
update=update, name=name)

@pytorch_wrap
def heatmap(self, X, win=None, env=None, opts=None):
"""
This function draws a heatmap. It takes as input an `NxM` tensor `X`
Expand Down Expand Up @@ -1150,6 +1172,7 @@ def heatmap(self, X, win=None, env=None, opts=None):
'opts': opts,
})

@pytorch_wrap
def bar(self, X, Y=None, win=None, env=None, opts=None):
"""
This function draws a regular, stacked, or grouped bar plot. It takes as
Expand Down Expand Up @@ -1215,6 +1238,7 @@ def bar(self, X, Y=None, win=None, env=None, opts=None):
'opts': opts,
})

@pytorch_wrap
def histogram(self, X, win=None, env=None, opts=None):
"""
This function draws a histogram of the specified data. It takes as input
Expand Down Expand Up @@ -1246,6 +1270,7 @@ def histogram(self, X, win=None, env=None, opts=None):
env=env
)

@pytorch_wrap
def boxplot(self, X, win=None, env=None, opts=None):
"""
This function draws boxplots of the specified data. It takes as input
Expand Down Expand Up @@ -1290,6 +1315,7 @@ def boxplot(self, X, win=None, env=None, opts=None):
'opts': opts,
})

@pytorch_wrap
def _surface(self, X, stype, win=None, env=None, opts=None):
"""
This function draws a surface plot. It takes as input an `NxM` tensor
Expand Down Expand Up @@ -1330,6 +1356,7 @@ def _surface(self, X, stype, win=None, env=None, opts=None):
'opts': opts,
})

@pytorch_wrap
def surf(self, X, win=None, env=None, opts=None):
"""
This function draws a surface plot. It takes as input an `NxM` tensor
Expand All @@ -1344,6 +1371,7 @@ def surf(self, X, win=None, env=None, opts=None):

return self._surface(X=X, stype='surface', opts=opts, win=win, env=env)

@pytorch_wrap
def contour(self, X, win=None, env=None, opts=None):
"""
This function draws a contour plot. It takes as input an `NxM` tensor
Expand All @@ -1358,6 +1386,7 @@ def contour(self, X, win=None, env=None, opts=None):

return self._surface(X=X, stype='contour', opts=opts, win=win, env=env)

@pytorch_wrap
def quiver(self, X, Y, gridX=None, gridY=None,
win=None, env=None, opts=None):
"""
Expand Down Expand Up @@ -1441,6 +1470,7 @@ def quiver(self, X, Y, gridX=None, gridY=None,
# generate scatter plot:
return self.scatter(X=data, opts=opts, win=win, env=env)

@pytorch_wrap
def stem(self, X, Y=None, win=None, env=None, opts=None):
"""
This function draws a stem plot. It takes as input an `N` or `NxM`tensor
Expand Down Expand Up @@ -1488,6 +1518,7 @@ def stem(self, X, Y=None, win=None, env=None, opts=None):

return self.scatter(X=data, Y=labels, opts=opts, win=win, env=env)

@pytorch_wrap
def pie(self, X, win=None, env=None, opts=None):
"""
This function draws a pie chart based on the `N` tensor `X`.
Expand Down Expand Up @@ -1519,6 +1550,7 @@ def pie(self, X, win=None, env=None, opts=None):
'opts': opts,
})

@pytorch_wrap
def mesh(self, X, Y=None, win=None, env=None, opts=None):
"""
This function draws a mesh plot from a set of vertices defined in an
Expand Down
15 changes: 13 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,24 @@
from pkg_resources import get_distribution, DistributionNotFound


try:
import torch
if (torch.__version__ < "0.3.1"):
print(
"[visdom] WARNING: Visdom support for pytorch less than version "
"0.3.1 is unsupported. Visdom will still work for other purposes "
"though."
)
except Exception:
pass # User doesn't have torch


def get_dist(pkgname):
try:
return get_distribution(pkgname)
except DistributionNotFound:
return None


here = os.path.abspath(os.path.dirname(__file__))

with open(os.path.join(here, 'py/visdom/VERSION')) as version_file:
Expand All @@ -34,7 +45,7 @@ def get_dist(pkgname):
# Metadata
name='visdom',
version=version,
author='Allan Jabri, Jack Urbanek, Laurens van der Maaten',
author='Jack Urbanek, Allan Jabri, Laurens van der Maaten',
author_email='[email protected]',
url='https://github.com/facebookresearch/visdom',
description='A tool for visualizing live, rich data for Torch and Numpy',
Expand Down

0 comments on commit 2305c7c

Please sign in to comment.