From 6ba4badb6e46624874c6709102a288ea80611673 Mon Sep 17 00:00:00 2001 From: Orion Reblitz-Richardson Date: Thu, 25 Oct 2018 09:05:50 -0700 Subject: [PATCH] Fixes for Caffe2 and tests to switch over to lanpa/tensorboardX (#259) * Fix caffe2_graph to match latest PyTorch/Caffe2 master * Use unittest classes for tests * Fix some issues with Caffe2 writer and improve tests --- tensorboardX/caffe2_graph.py | 7 ++-- tensorboardX/writer.py | 10 +++--- tests/test_caffe2.py | 5 +++ tests/test_chainer_np.py | 46 ++++++++++++------------ tests/test_figure.py | 69 ++++++++++++++++++++---------------- tests/test_numpy.py | 35 ++++++++++-------- tests/test_pytorch_np.py | 57 +++++++++++++++-------------- tests/test_summary.py | 47 ++++++++++++------------ tests/test_summary_writer.py | 34 +++++++++--------- 9 files changed, 171 insertions(+), 139 deletions(-) diff --git a/tensorboardX/caffe2_graph.py b/tensorboardX/caffe2_graph.py index 341d1f7c..926e3c2a 100644 --- a/tensorboardX/caffe2_graph.py +++ b/tensorboardX/caffe2_graph.py @@ -316,7 +316,10 @@ def _tf_device(device_option): if device_option.device_type == caffe2_pb2.CPU: return "/cpu:*" if device_option.device_type == caffe2_pb2.CUDA: - return "/gpu:{}".format(device_option.cuda_gpu_id) + if device_option.HasField("device_id"): + return "/gpu:{}".format(device_option.device_id) + elif device_option.HasField("cuda_gpu_id"): + return "/gpu:{}".format(device_option.cuda_gpu_id) raise Exception("Unhandled device", device_option) @@ -665,8 +668,6 @@ def _operators_to_graph_def( Returns: current_graph: GraphDef representing the computation graph formed by the set of operators. - blob_name_tracker: (Filtered) list of blob names corresponding to input - and output nodes of the operators in the graph. ''' if blob_name_tracker is not None: blob_name_tracker.clear() diff --git a/tensorboardX/writer.py b/tensorboardX/writer.py index 85f237a1..1461c886 100644 --- a/tensorboardX/writer.py +++ b/tensorboardX/writer.py @@ -541,20 +541,22 @@ def add_graph(self, model, input_to_model=None, verbose=False, **kwargs): return from caffe2.proto import caffe2_pb2 from caffe2.python import core - from .caffe2_graph import model_to_graph, nets_to_graph, protos_to_graph + from .caffe2_graph import ( + model_to_graph_def, nets_to_graph_def, protos_to_graph_def + ) # notimporterror should be already handled when checking self.caffe2_enabled '''Write graph to the summary. Check model type and handle accordingly.''' if isinstance(model, list): if isinstance(model[0], core.Net): - current_graph, track_blob_names = nets_to_graph( + current_graph = nets_to_graph_def( model, **kwargs) elif isinstance(model[0], caffe2_pb2.NetDef): - current_graph, track_blob_names = protos_to_graph( + current_graph = protos_to_graph_def( model, **kwargs) # Handles cnn.CNNModelHelper, model_helper.ModelHelper else: - current_graph, track_blob_names = model_to_graph( + current_graph = model_to_graph_def( model, **kwargs) event = event_pb2.Event( graph_def=current_graph.SerializeToString()) diff --git a/tests/test_caffe2.py b/tests/test_caffe2.py index a6739cdd..a2ad37d3 100644 --- a/tests/test_caffe2.py +++ b/tests/test_caffe2.py @@ -3,6 +3,7 @@ from __future__ import print_function from __future__ import unicode_literals +from tensorboardX import SummaryWriter import unittest try: @@ -1694,6 +1695,8 @@ def test_simple_cnnmodel(self): model.net.RunAllOnGPU() model.param_init_net.RunAllOnGPU() model.AddGradientOperators([loss], skip=1) + with SummaryWriter(filename_suffix='.test') as writer: + writer.add_graph(model) blob_name_tracker = {} graph = tb.model_to_graph_def( model, @@ -1754,6 +1757,8 @@ def test_simple_model(self): model.net.RunAllOnGPU() model.param_init_net.RunAllOnGPU() model.AddGradientOperators([loss], skip=1) + with SummaryWriter(filename_suffix='.test') as writer: + writer.add_graph(model) blob_name_tracker = {} graph = tb.model_to_graph_def( model, diff --git a/tests/test_chainer_np.py b/tests/test_chainer_np.py index e3c8b00c..e9b131b9 100644 --- a/tests/test_chainer_np.py +++ b/tests/test_chainer_np.py @@ -1,3 +1,8 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + from tensorboardX import x2num, SummaryWriter try: import chainer @@ -6,36 +11,33 @@ print('Chainer is not installed, skipping test') chainer_installed = False import numpy as np +import unittest + + if chainer_installed: chainer.Variable tensors = [chainer.Variable(np.random.rand(3, 10, 10)), chainer.Variable(np.random.rand(1)), chainer.Variable(np.random.rand(1, 2, 3, 4, 5))] + class ChainerTest(unittest.TestCase): + def test_chainer_np(self): + for tensor in tensors: + # regular variable + assert isinstance(x2num.make_np(tensor), np.ndarray) -def test_chainer_np(): - if not chainer_installed: - return - for tensor in tensors: - # regular variable - assert isinstance(x2num.make_np(tensor), np.ndarray) - - # python primitive type - assert(isinstance(x2num.make_np(0), np.ndarray)) - assert(isinstance(x2num.make_np(0.1), np.ndarray)) + # python primitive type + assert(isinstance(x2num.make_np(0), np.ndarray)) + assert(isinstance(x2num.make_np(0.1), np.ndarray)) -def test_chainer_img(): - if not chainer_installed: - return - shapes = [(77, 3, 13, 7), (77, 1, 13, 7), (3, 13, 7), (1, 13, 7), (13, 7)] - for s in shapes: - x = chainer.Variable(np.random.random_sample(s)) - assert x2num.make_np(x, 'IMG').shape[2] == 3 + def test_chainer_img(self): + shapes = [(77, 3, 13, 7), (77, 1, 13, 7), (3, 13, 7), (1, 13, 7), (13, 7)] + for s in shapes: + x = chainer.Variable(np.random.random_sample(s)) + assert x2num.make_np(x, 'IMG').shape[2] == 3 -def test_chainer_write(): - if not chainer_installed: - return - with SummaryWriter() as w: - w.add_scalar('scalar', chainer.Variable(np.random.rand(1)), 0) + def test_chainer_write(self): + with SummaryWriter() as w: + w.add_scalar('scalar', chainer.Variable(np.random.rand(1)), 0) diff --git a/tests/test_figure.py b/tests/test_figure.py index ad7947eb..dbe5c74f 100644 --- a/tests/test_figure.py +++ b/tests/test_figure.py @@ -1,44 +1,51 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + import matplotlib.pyplot as plt -from tensorboardX import SummaryWriter +import unittest +from tensorboardX import SummaryWriter -def test_figure(): - writer = SummaryWriter() - figure, axes = plt.figure(), plt.gca() - circle1 = plt.Circle((0.2, 0.5), 0.2, color='r') - circle2 = plt.Circle((0.8, 0.5), 0.2, color='g') - axes.add_patch(circle1) - axes.add_patch(circle2) - plt.axis('scaled') - plt.tight_layout() +class FigureTest(unittest.TestCase): + def test_figure(self): + writer = SummaryWriter() - writer.add_figure("add_figure/figure", figure, 0, close=False) - assert plt.fignum_exists(figure.number) is True + figure, axes = plt.figure(), plt.gca() + circle1 = plt.Circle((0.2, 0.5), 0.2, color='r') + circle2 = plt.Circle((0.8, 0.5), 0.2, color='g') + axes.add_patch(circle1) + axes.add_patch(circle2) + plt.axis('scaled') + plt.tight_layout() - writer.add_figure("add_figure/figure", figure, 1) - assert plt.fignum_exists(figure.number) is False + writer.add_figure("add_figure/figure", figure, 0, close=False) + assert plt.fignum_exists(figure.number) is True - writer.close() + writer.add_figure("add_figure/figure", figure, 1) + assert plt.fignum_exists(figure.number) is False + writer.close() -def test_figure_list(): - writer = SummaryWriter() + def test_figure_list(self): + writer = SummaryWriter() - figures = [] - for i in range(5): - figure = plt.figure() - plt.plot([i * 1, i * 2, i * 3], label="Plot " + str(i)) - plt.xlabel("X") - plt.xlabel("Y") - plt.legend() - plt.tight_layout() - figures.append(figure) + figures = [] + for i in range(5): + figure = plt.figure() + plt.plot([i * 1, i * 2, i * 3], label="Plot " + str(i)) + plt.xlabel("X") + plt.xlabel("Y") + plt.legend() + plt.tight_layout() + figures.append(figure) - writer.add_figure("add_figure/figure_list", figures, 0, close=False) - assert all([plt.fignum_exists(figure.number) is True for figure in figures]) + writer.add_figure("add_figure/figure_list", figures, 0, close=False) + assert all([plt.fignum_exists(figure.number) is True for figure in figures]) - writer.add_figure("add_figure/figure_list", figures, 1) - assert all([plt.fignum_exists(figure.number) is False for figure in figures]) + writer.add_figure("add_figure/figure_list", figures, 1) + assert all([plt.fignum_exists(figure.number) is False for figure in figures]) - writer.close() + writer.close() diff --git a/tests/test_numpy.py b/tests/test_numpy.py index 186e6cb9..8f62bb7a 100644 --- a/tests/test_numpy.py +++ b/tests/test_numpy.py @@ -1,19 +1,26 @@ -from tensorboardX import x2num +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + import numpy as np +import unittest +from tensorboardX import x2num -def test_scalar(): - res = x2num.make_np(1.1) - assert isinstance(res, np.ndarray) and res.shape == (1,) - res = x2num.make_np(1000000000000000000000) - assert isinstance(res, np.ndarray) and res.shape == (1,) - res = x2num.make_np(np.float16(1.00000087)) - assert isinstance(res, np.ndarray) and res.shape == (1,) - res = x2num.make_np(np.float128(1.00008 + 9)) - assert isinstance(res, np.ndarray) and res.shape == (1,) - res = x2num.make_np(np.int64(100000000000)) - assert isinstance(res, np.ndarray) and res.shape == (1,) +class NumpyTest(unittest.TestCase): + def test_scalar(self): + res = x2num.make_np(1.1) + assert isinstance(res, np.ndarray) and res.shape == (1,) + res = x2num.make_np(1000000000000000000000) + assert isinstance(res, np.ndarray) and res.shape == (1,) + res = x2num.make_np(np.float16(1.00000087)) + assert isinstance(res, np.ndarray) and res.shape == (1,) + res = x2num.make_np(np.float128(1.00008 + 9)) + assert isinstance(res, np.ndarray) and res.shape == (1,) + res = x2num.make_np(np.int64(100000000000)) + assert isinstance(res, np.ndarray) and res.shape == (1,) -def test_make_grid(): - pass + def test_make_grid(self): + pass diff --git a/tests/test_pytorch_np.py b/tests/test_pytorch_np.py index 0af0ab99..fe504daf 100644 --- a/tests/test_pytorch_np.py +++ b/tests/test_pytorch_np.py @@ -1,37 +1,42 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function +from __future__ import unicode_literals + from tensorboardX import x2num, SummaryWriter import torch import numpy as np -tensors = [torch.rand(3, 10, 10), torch.rand(1), torch.rand(1, 2, 3, 4, 5)] - - -def test_pytorch_np(): - for tensor in tensors: - # regular tensor - assert isinstance(x2num.make_np(tensor), np.ndarray) +import unittest - # CUDA tensor - if torch.cuda.device_count() > 0: - assert isinstance(x2num.make_np(tensor.cuda()), np.ndarray) - # regular variable - assert isinstance(x2num.make_np(torch.autograd.Variable(tensor)), np.ndarray) +class PyTorchNumpyTest(unittest.TestCase): + def test_pytorch_np(self): + tensors = [torch.rand(3, 10, 10), torch.rand(1), torch.rand(1, 2, 3, 4, 5)] + for tensor in tensors: + # regular tensor + assert isinstance(x2num.make_np(tensor), np.ndarray) - # CUDA variable - if torch.cuda.device_count() > 0: - assert isinstance(x2num.make_np(torch.autograd.Variable(tensor).cuda()), np.ndarray) + # CUDA tensor + if torch.cuda.device_count() > 0: + assert isinstance(x2num.make_np(tensor.cuda()), np.ndarray) - # python primitive type - assert(isinstance(x2num.make_np(0), np.ndarray)) - assert(isinstance(x2num.make_np(0.1), np.ndarray)) + # regular variable + assert isinstance(x2num.make_np(torch.autograd.Variable(tensor)), np.ndarray) + # CUDA variable + if torch.cuda.device_count() > 0: + assert isinstance(x2num.make_np(torch.autograd.Variable(tensor).cuda()), np.ndarray) -def test_pytorch_img(): - shapes = [(77, 3, 13, 7), (77, 1, 13, 7), (3, 13, 7), (1, 13, 7), (13, 7)] - for s in shapes: - x = torch.Tensor(np.random.random_sample(s)) - assert x2num.make_np(x, 'IMG').shape[2] == 3 + # python primitive type + assert(isinstance(x2num.make_np(0), np.ndarray)) + assert(isinstance(x2num.make_np(0.1), np.ndarray)) + def test_pytorch_img(self): + shapes = [(77, 3, 13, 7), (77, 1, 13, 7), (3, 13, 7), (1, 13, 7), (13, 7)] + for s in shapes: + x = torch.Tensor(np.random.random_sample(s)) + assert x2num.make_np(x, 'IMG').shape[2] == 3 -def test_pytorch_write(): - with SummaryWriter() as w: - w.add_scalar('scalar', torch.autograd.Variable(torch.rand(1)), 0) + def test_pytorch_write(self): + with SummaryWriter() as w: + w.add_scalar('scalar', torch.autograd.Variable(torch.rand(1)), 0) diff --git a/tests/test_summary.py b/tests/test_summary.py index 1a9d902c..f9efb4f9 100644 --- a/tests/test_summary.py +++ b/tests/test_summary.py @@ -2,31 +2,32 @@ import numpy as np import pytest +import unittest -def test_uint8_image(): - ''' - Tests that uint8 image (pixel values in [0, 255]) is not changed - ''' - test_image = np.random.randint(0, 256, size=(3, 32, 32), dtype=np.uint8) - scale_factor = summary._calc_scale_factor(test_image) - assert scale_factor == 1, 'Values are already in [0, 255], scale factor should be 1' +class SummaryTest(unittest.TestCase): + def test_uint8_image(self): + ''' + Tests that uint8 image (pixel values in [0, 255]) is not changed + ''' + test_image = np.random.randint(0, 256, size=(3, 32, 32), dtype=np.uint8) + scale_factor = summary._calc_scale_factor(test_image) + assert scale_factor == 1, 'Values are already in [0, 255], scale factor should be 1' + def test_float32_image(self): + ''' + Tests that float32 image (pixel values in [0, 1]) are scaled correctly + to [0, 255] + ''' + test_image = np.random.rand(3, 32, 32).astype(np.float32) + scale_factor = summary._calc_scale_factor(test_image) + assert scale_factor == 255, 'Values are in [0, 1], scale factor should be 255' -def test_float32_image(): - ''' - Tests that float32 image (pixel values in [0, 1]) are scaled correctly - to [0, 255] - ''' - test_image = np.random.rand(3, 32, 32).astype(np.float32) - scale_factor = summary._calc_scale_factor(test_image) - assert scale_factor == 255, 'Values are in [0, 1], scale factor should be 255' + def test_list_input(self): + with pytest.raises(Exception) as e_info: + summary.histogram('dummy', [1,3,4,5,6], 'tensorflow') -def test_list_input(): - with pytest.raises(Exception) as e_info: - summary.histogram('dummy', [1,3,4,5,6], 'tensorflow') - -def test_empty_input(): - print('expect error here:') - with pytest.raises(Exception) as e_info: - summary.histogram('dummy', np.ndarray(0), 'tensorflow') \ No newline at end of file + def test_empty_input(self): + print('expect error here:') + with pytest.raises(Exception) as e_info: + summary.histogram('dummy', np.ndarray(0), 'tensorflow') diff --git a/tests/test_summary_writer.py b/tests/test_summary_writer.py index 3979c0da..d7043cce 100644 --- a/tests/test_summary_writer.py +++ b/tests/test_summary_writer.py @@ -1,22 +1,24 @@ from tensorboardX import SummaryWriter +import unittest -def test_summary_writer_ctx(): - # after using a SummaryWriter as a ctx it should be closed - with SummaryWriter(filename_suffix='.test') as writer: - writer.add_scalar('test', 1) - assert writer.file_writer is None +class SummaryWriterTest(unittest.TestCase): + def test_summary_writer_ctx(self): + # after using a SummaryWriter as a ctx it should be closed + with SummaryWriter(filename_suffix='.test') as writer: + writer.add_scalar('test', 1) + assert writer.file_writer is None -def test_summary_writer_close(): - # Opening and closing SummaryWriter a lot should not run into - # OSError: [Errno 24] Too many open files - passed = True - try: - for _ in range(2048): - writer = SummaryWriter() - writer.close() - except OSError: - passed = False + def test_summary_writer_close(self): + # Opening and closing SummaryWriter a lot should not run into + # OSError: [Errno 24] Too many open files + passed = True + try: + for _ in range(2048): + writer = SummaryWriter() + writer.close() + except OSError: + passed = False - assert passed + assert passed