Skip to content

Commit

Permalink
Remained changes of #43578 (#43921)
Browse files Browse the repository at this point in the history
Summary:
Not full pytorch/pytorch#43578 was merged. This PR is the remained part.

Pull Request resolved: pytorch/pytorch#43921

Reviewed By: ailzhang

Differential Revision: D23438504

Pulled By: mruberry

fbshipit-source-id: 9c5e26346dfc423b7a440b8a986420a27349090f
  • Loading branch information
xuhdev authored and facebook-github-bot committed Sep 1, 2020
1 parent 3c2f6d2 commit 69fbc70
Showing 1 changed file with 15 additions and 1 deletion.
16 changes: 15 additions & 1 deletion test/test_tensor_creation_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -716,7 +716,8 @@ def test_eye(self, device):

@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
@precisionOverride({torch.float: 1e-8, torch.double: 1e-10})
@dtypes(*([torch.float, torch.double] + torch.testing.get_all_complex_dtypes()))
@dtypes(*(torch.testing.get_all_fp_dtypes(include_half=False, include_bfloat16=False) +
torch.testing.get_all_complex_dtypes()))
def test_linspace_vs_numpy(self, device, dtype):
start = -0.0316082797944545745849609375 + (0.8888888888j if dtype.is_complex else 0)
end = .0315315723419189453125 + (0.444444444444j if dtype.is_complex else 0)
Expand Down Expand Up @@ -940,6 +941,10 @@ def test_linspace(self, device, dtype):
torch.tensor((2, 1, 0), device=device, dtype=dtype),
atol=0, rtol=0)

# Create non-complex tensor from complex numbers
if not dtype.is_complex:
self.assertRaises(RuntimeError, lambda: torch.linspace(1j, 2j, 3, device=device, dtype=dtype))

# Check for race condition (correctness when applied on a large tensor).
if dtype not in (torch.int8, torch.uint8, torch.int16, torch.half, torch.bfloat16):
y = torch.linspace(0, 999999 + (999999j if dtype.is_complex else 0),
Expand All @@ -956,6 +961,15 @@ def test_linspace(self, device, dtype):
y = torch.linspace(0, 3, 4, out=x.narrow(1, 1, 2), dtype=dtype)
self.assertEqual(x, torch.tensor(((0, 0, 1), (0, 2, 3)), device=device, dtype=dtype), atol=0, rtol=0)

def test_linspace_deduction(self, device):
# Test deduction from input parameters.
self.assertEqual(torch.linspace(1, 2, device=device).dtype, torch.float32)
self.assertEqual(torch.linspace(1., 2, device=device).dtype, torch.float32)
self.assertEqual(torch.linspace(1., -2., device=device).dtype, torch.float32)
# TODO: Need fix
with self.assertRaises(RuntimeError):
torch.linspace(1j, -2j, device=device)

# The implementation of linspace+logspace goes through a different path
# when the steps arg is equal to 0 or 1. For other values of `steps`
# they call specialized linspace (or logspace) kernels.
Expand Down

0 comments on commit 69fbc70

Please sign in to comment.