Improve repr for torch.iinfo & torch.finfo (#40488)
Summary:
- Fixes pytorch/pytorch#39991
- Include the `min`/`max`/`eps`/`tiny` values directly in the repr of `torch.iinfo` & `torch.finfo` for easier inspection
- Use dtype names such as `torch.float16` / `torch.int16` instead of the legacy names `Half` / `Short`
- The improved repr looks like this:
```
>>> torch.iinfo(torch.int8)
iinfo(type=torch.int8, max=127, min=-128)
>>> torch.iinfo(torch.int16)
iinfo(type=torch.int16, max=32767, min=-32768)
>>> torch.iinfo(torch.int32)
iinfo(type=torch.int32, max=2.14748e+09, min=-2.14748e+09)
>>> torch.iinfo(torch.int64)
iinfo(type=torch.int64, max=9.22337e+18, min=-9.22337e+18)
>>> torch.finfo(torch.float16)
finfo(type=torch.float16, eps=0.000976563, max=65504, min=-65504, tiny=6.10352e-05)
>>> torch.finfo(torch.float32)
finfo(type=torch.float32, eps=1.19209e-07, max=3.40282e+38, min=-3.40282e+38, tiny=1.17549e-38)
>>> torch.finfo(torch.float64)
finfo(type=torch.float64, eps=2.22045e-16, max=1.79769e+308, min=-1.79769e+308, tiny=2.22507e-308)
```
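
The patch also exposes `resolution` and `dtype` attributes on both info objects. A minimal sketch of inspecting them (values shown for `torch.float32` / `torch.int32`; exact float formatting may vary):

```
>>> info = torch.finfo(torch.float32)
>>> info.resolution   # 10 ** -precision, i.e. 1e-06 for float32
1e-06
>>> info.dtype        # primary dtype name as a string
'float32'
>>> torch.iinfo(torch.int32).dtype
'int32'
```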

Pull Request resolved: pytorch/pytorch#40488

Differential Revision: D22445301

Pulled By: mruberry

fbshipit-source-id: 552af9904c423006084b45d6c4adfb4b5689db54
Kiyosora authored and facebook-github-bot committed Jul 10, 2020
1 parent cb6c352 commit 0651887
Showing 6 changed files with 91 additions and 27 deletions.
1 change: 1 addition & 0 deletions docs/source/type_info.rst
@@ -27,6 +27,7 @@ eps float The smallest representable number such that ``1.0 + eps != 1.0``.
max float The largest representable number.
min float The smallest representable number (typically ``-max``).
tiny float The smallest positive representable number.
resolution float The approximate decimal resolution of this type, i.e., ``10**-precision``.
========= ===== ========================================
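
As a quick illustrative check of the new row (values assumed for `torch.float64`, which has 15 decimal digits of precision, and `torch.float16`, which has 3):

```
>>> torch.finfo(torch.float64).resolution   # 10 ** -15
1e-15
>>> torch.finfo(torch.float16).resolution   # 10 ** -3
0.001
```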

.. note::
30 changes: 24 additions & 6 deletions test/test_type_info.py
@@ -14,29 +14,30 @@
class TestDTypeInfo(TestCase):

def test_invalid_input(self):
for dtype in [torch.float32, torch.float64]:
for dtype in [torch.float16, torch.float32, torch.float64, torch.bfloat16, torch.complex64, torch.complex128, torch.bool]:
with self.assertRaises(TypeError):
_ = torch.iinfo(dtype)

for dtype in [torch.int64, torch.int32, torch.int16, torch.uint8]:
for dtype in [torch.int64, torch.int32, torch.int16, torch.int8, torch.uint8, torch.bool]:
with self.assertRaises(TypeError):
_ = torch.finfo(dtype)

@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_iinfo(self):
for dtype in [torch.int64, torch.int32, torch.int16, torch.uint8]:
for dtype in [torch.int64, torch.int32, torch.int16, torch.int8, torch.uint8]:
x = torch.zeros((2, 2), dtype=dtype)
xinfo = torch.iinfo(x.dtype)
xn = x.cpu().numpy()
xninfo = np.iinfo(xn.dtype)
self.assertEqual(xinfo.bits, xninfo.bits)
self.assertEqual(xinfo.max, xninfo.max)
self.assertEqual(xinfo.min, xninfo.min)
self.assertEqual(xinfo.dtype, xninfo.dtype)

@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_finfo(self):
initial_default_type = torch.get_default_dtype()
for dtype in [torch.float32, torch.float64]:
for dtype in [torch.float16, torch.float32, torch.float64, torch.complex64, torch.complex128]:
x = torch.zeros((2, 2), dtype=dtype)
xinfo = torch.finfo(x.dtype)
xn = x.cpu().numpy()
@@ -46,8 +47,25 @@ def test_finfo(self):
self.assertEqual(xinfo.min, xninfo.min)
self.assertEqual(xinfo.eps, xninfo.eps)
self.assertEqual(xinfo.tiny, xninfo.tiny)
torch.set_default_dtype(dtype)
self.assertEqual(torch.finfo(dtype), torch.finfo())
self.assertEqual(xinfo.resolution, xninfo.resolution)
self.assertEqual(xinfo.dtype, xninfo.dtype)
if not dtype.is_complex:
torch.set_default_dtype(dtype)
self.assertEqual(torch.finfo(dtype), torch.finfo())

# Special test case for BFloat16 type
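# (NumPy has no native bfloat16 dtype, so expected values are hard-coded instead of compared against np.finfo)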
x = torch.zeros((2, 2), dtype=torch.bfloat16)
xinfo = torch.finfo(x.dtype)
self.assertEqual(xinfo.bits, 16)
self.assertEqual(xinfo.max, 3.38953e+38)
self.assertEqual(xinfo.min, -3.38953e+38)
self.assertEqual(xinfo.eps, 0.0078125)
self.assertEqual(xinfo.tiny, 1.17549e-38)
self.assertEqual(xinfo.resolution, 0.01)
self.assertEqual(xinfo.dtype, "bfloat16")
torch.set_default_dtype(x.dtype)
self.assertEqual(torch.finfo(x.dtype), torch.finfo())

# Restore the default type to ensure that the test has no side effect
torch.set_default_dtype(initial_default_type)

3 changes: 3 additions & 0 deletions torch/_C/__init__.pyi.in
@@ -59,6 +59,7 @@ class iinfo:
bits: _int
min: _int
max: _int
dtype: str

def __init__(self, dtype: _dtype) -> None: ...

@@ -68,6 +69,8 @@ class finfo:
max: _float
eps: _float
tiny: _float
resolution: _float
dtype: str

@overload
def __init__(self, dtype: _dtype) -> None: ...
80 changes: 60 additions & 20 deletions torch/csrc/TypeInfo.cpp
@@ -6,6 +6,7 @@
#include <torch/csrc/utils/python_arg_parser.h>
#include <torch/csrc/utils/python_numbers.h>
#include <torch/csrc/utils/python_strings.h>
#include <torch/csrc/utils/tensor_dtypes.h>

#include <c10/util/Exception.h>

@@ -20,7 +21,7 @@ PyObject* THPFInfo_New(const at::ScalarType& type) {
if (!self)
throw python_error();
auto self_ = reinterpret_cast<THPDTypeInfo*>(self.get());
self_->type = type;
self_->type = c10::toValueType(type);
return self.release();
}

@@ -34,18 +35,6 @@ PyObject* THPIInfo_New(const at::ScalarType& type) {
return self.release();
}

PyObject* THPFInfo_str(THPFInfo* self) {
std::ostringstream oss;
oss << "finfo(type=" << self->type << ")";
return THPUtils_packString(oss.str().c_str());
}

PyObject* THPIInfo_str(THPIInfo* self) {
std::ostringstream oss;
oss << "iinfo(type=" << self->type << ")";
return THPUtils_packString(oss.str().c_str());
}

PyObject* THPFInfo_pynew(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
HANDLE_TH_ERRORS
static torch::PythonArgParser parser({
@@ -63,7 +52,7 @@ PyObject* THPFInfo_pynew(PyTypeObject* type, PyObject* args, PyObject* kwargs) {
AT_ASSERT(at::isFloatingType(scalar_type));
} else {
scalar_type = r.scalartype(0);
if (!at::isFloatingType(scalar_type)) {
if (!at::isFloatingType(scalar_type) && !at::isComplexType(scalar_type)) {
return PyErr_Format(
PyExc_TypeError,
"torch.finfo() requires a floating point input type. Use torch.iinfo to handle '%s'",
@@ -123,7 +112,7 @@ static PyObject* THPDTypeInfo_bits(THPDTypeInfo* self, void*) {
}

static PyObject* THPFInfo_eps(THPFInfo* self, void*) {
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(at::kHalf,
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::kHalf, at::ScalarType::BFloat16,
self->type, "epsilon", [] {
return PyFloat_FromDouble(
std::numeric_limits<
@@ -132,20 +121,20 @@ static PyObject* THPFInfo_eps(THPFInfo* self, void*) {
}

static PyObject* THPFInfo_max(THPFInfo* self, void*) {
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(at::kHalf, self->type, "max", [] {
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::kHalf, at::ScalarType::BFloat16, self->type, "max", [] {
return PyFloat_FromDouble(
std::numeric_limits<at::scalar_value_type<scalar_t>::type>::max());
});
}

static PyObject* THPFInfo_min(THPFInfo* self, void*) {
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(at::kHalf, self->type, "min", [] {
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::kHalf, at::ScalarType::BFloat16, self->type, "lowest", [] {
return PyFloat_FromDouble(
std::numeric_limits<at::scalar_value_type<scalar_t>::type>::lowest());
});
}

static PyObject* THPIInfo_max(THPFInfo* self, void*) {
static PyObject* THPIInfo_max(THPIInfo* self, void*) {
if (at::isIntegralType(self->type, /*includeBool=*/false)) {
return AT_DISPATCH_INTEGRAL_TYPES(self->type, "max", [] {
return THPUtils_packInt64(std::numeric_limits<scalar_t>::max());
Expand All @@ -157,7 +146,7 @@ static PyObject* THPIInfo_max(THPFInfo* self, void*) {
});
}

static PyObject* THPIInfo_min(THPFInfo* self, void*) {
static PyObject* THPIInfo_min(THPIInfo* self, void*) {
if (at::isIntegralType(self->type, /*includeBool=*/false)) {
return AT_DISPATCH_INTEGRAL_TYPES(self->type, "min", [] {
return THPUtils_packInt64(std::numeric_limits<scalar_t>::lowest());
@@ -169,19 +158,69 @@
});
}

static PyObject* THPIInfo_dtype(THPIInfo* self, void*) {
std::string primary_name, legacy_name;
std::tie(primary_name, legacy_name) = torch::utils::getDtypeNames(self->type);
return AT_DISPATCH_INTEGRAL_TYPES(self->type, "dtype", [primary_name] {
return PyUnicode_FromString((char*)primary_name.data());
});
}

static PyObject* THPFInfo_tiny(THPFInfo* self, void*) {
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND1(at::kHalf, self->type, "min", [] {
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::kHalf, at::ScalarType::BFloat16, self->type, "min", [] {
return PyFloat_FromDouble(
std::numeric_limits<at::scalar_value_type<scalar_t>::type>::min());
});
}

static PyObject* THPFInfo_resolution(THPFInfo* self, void*) {
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::kHalf, at::ScalarType::BFloat16, self->type, "digits10", [] {
return PyFloat_FromDouble(
std::pow(10, -std::numeric_limits<at::scalar_value_type<scalar_t>::type>::digits10));
});
}

static PyObject* THPFInfo_dtype(THPFInfo* self, void*) {
std::string primary_name, legacy_name;
std::tie(primary_name, legacy_name) = torch::utils::getDtypeNames(self->type);
return AT_DISPATCH_FLOATING_AND_COMPLEX_TYPES_AND2(at::kHalf, at::ScalarType::BFloat16, self->type, "dtype", [primary_name] {
return PyUnicode_FromString((char*)primary_name.data());
});
}

PyObject* THPFInfo_str(THPFInfo* self) {
std::ostringstream oss;
oss << "finfo(resolution=" << PyFloat_AsDouble(THPFInfo_resolution(self, nullptr));
oss << ", min=" << PyFloat_AsDouble(THPFInfo_min(self, nullptr));
oss << ", max=" << PyFloat_AsDouble(THPFInfo_max(self, nullptr));
oss << ", eps=" << PyFloat_AsDouble(THPFInfo_eps(self, nullptr));
oss << ", tiny=" << PyFloat_AsDouble(THPFInfo_tiny(self, nullptr));
oss << ", dtype=" << PyUnicode_AsUTF8(THPFInfo_dtype(self, nullptr)) << ")";

return THPUtils_packString(oss.str().c_str());
}

PyObject* THPIInfo_str(THPIInfo* self) {
auto type = self->type;
std::string primary_name, legacy_name;
std::tie(primary_name, legacy_name) = torch::utils::getDtypeNames(type);
std::ostringstream oss;

oss << "iinfo(min=" << PyFloat_AsDouble(THPIInfo_min(self, nullptr));
oss << ", max=" << PyFloat_AsDouble(THPIInfo_max(self, nullptr));
oss << ", dtype=" << PyUnicode_AsUTF8(THPIInfo_dtype(self, nullptr)) << ")";

return THPUtils_packString(oss.str().c_str());
}

static struct PyGetSetDef THPFInfo_properties[] = {
{"bits", (getter)THPDTypeInfo_bits, nullptr, nullptr, nullptr},
{"eps", (getter)THPFInfo_eps, nullptr, nullptr, nullptr},
{"max", (getter)THPFInfo_max, nullptr, nullptr, nullptr},
{"min", (getter)THPFInfo_min, nullptr, nullptr, nullptr},
{"tiny", (getter)THPFInfo_tiny, nullptr, nullptr, nullptr},
{"resolution", (getter)THPFInfo_resolution, nullptr, nullptr, nullptr},
{"dtype", (getter)THPFInfo_dtype, nullptr, nullptr, nullptr},
{nullptr}};

static PyMethodDef THPFInfo_methods[] = {
@@ -232,6 +271,7 @@ static struct PyGetSetDef THPIInfo_properties[] = {
{"bits", (getter)THPDTypeInfo_bits, nullptr, nullptr, nullptr},
{"max", (getter)THPIInfo_max, nullptr, nullptr, nullptr},
{"min", (getter)THPIInfo_min, nullptr, nullptr, nullptr},
{"dtype", (getter)THPIInfo_dtype, nullptr, nullptr, nullptr},
{nullptr}};

static PyMethodDef THPIInfo_methods[] = {
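
A sketch of the output the new `__str__` implementations above produce (float fields go through the default stream precision, so the exact digits shown are illustrative):

```
>>> torch.finfo(torch.float32)
finfo(resolution=1e-06, min=-3.40282e+38, max=3.40282e+38, eps=1.19209e-07, tiny=1.17549e-38, dtype=float32)
>>> torch.iinfo(torch.int8)
iinfo(min=-128, max=127, dtype=int8)
```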
2 changes: 1 addition & 1 deletion torch/csrc/utils/tensor_dtypes.cpp
@@ -10,7 +10,7 @@
namespace torch {
namespace utils {

static std::pair<std::string, std::string> getDtypeNames(
std::pair<std::string, std::string> getDtypeNames(
at::ScalarType scalarType) {
switch (scalarType) {
case at::ScalarType::Byte:
2 changes: 2 additions & 0 deletions torch/csrc/utils/tensor_dtypes.h
@@ -6,6 +6,8 @@

namespace torch { namespace utils {

std::pair<std::string, std::string> getDtypeNames(at::ScalarType scalarType);

void initializeDtypes();

}} // namespace torch::utils
