Skip to content

Commit

Permalink
Add better support for testing torch.allclose (facebookresearch#652)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookresearch#652

Implemented a new `ClassyTestCase` with support for `assertTorchAllClose`. This prints in the assertion what the failure inputs are which makes writing tests easier.

Reviewed By: kazhang

Differential Revision: D24911128

fbshipit-source-id: b462b722a8d4e429f726924007d4c4855aabece1
  • Loading branch information
mannatsingh authored and facebook-github-bot committed Nov 16, 2020
1 parent 932aabc commit ad95b93
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 0 deletions.
16 changes: 16 additions & 0 deletions test/generic/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
# LICENSE file in the root directory of this source tree.

import os
import unittest
from functools import wraps

import torch
Expand Down Expand Up @@ -298,3 +299,18 @@ def train(self, task: ClassyTask):
super().train(task)
except LimitedPhaseException:
pass


class ClassyTestCase(unittest.TestCase):
def assertTorchAllClose(
self, tensor_1, tensor_2, rtol=1e-5, atol=1e-8, equal_nan=False
):
for tensor in [tensor_1, tensor_2]:
if not isinstance(tensor, torch.Tensor):
raise AssertionError(
f"Expected tensor, not {tensor} of type {type(tensor)}"
)
if not torch.allclose(
tensor_1, tensor_2, rtol=rtol, atol=atol, equal_nan=equal_nan
):
raise AssertionError(f"{tensor_1} is not close to {tensor_2}")
37 changes: 37 additions & 0 deletions test/test_generic_utils_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from test.generic.utils import ClassyTestCase

import torch


class TestClassyTestCase(unittest.TestCase):
def test_assert_torch_all_close(self):
test_fixture = ClassyTestCase()

data = [1.1, 2.2]
tensor_1 = torch.Tensor(data)

# shouldn't raise an exception
tensor_2 = tensor_1
test_fixture.assertTorchAllClose(tensor_1, tensor_2)

# should fail because tensors are not close
tensor_2 = tensor_1 / 2
with self.assertRaises(AssertionError):
test_fixture.assertTorchAllClose(tensor_1, tensor_2)

# should fail because tensor_2 is not a tensor
tensor_2 = data
with self.assertRaises(AssertionError):
test_fixture.assertTorchAllClose(tensor_1, tensor_2)

# should fail because tensor_1 is not a tensor
tensor_1 = data
tensor_2 = torch.Tensor(data)
with self.assertRaises(AssertionError):
test_fixture.assertTorchAllClose(tensor_1, tensor_2)

0 comments on commit ad95b93

Please sign in to comment.