Expose is_signed for dtype (#29511)
Summary:
Changelog:
- Expose is_signed for torch.dtype by modifying torch/csrc/Dtype.cpp
- Allow half, bfloat16 and bool to also been "known" by the isSignedType function
Pull Request resolved: https://github.com/pytorch/pytorch/pull/29511
Test Plan:
- Add tests in test/test_torch.py
Closes https://github.com/pytorch/pytorch/issues/29475
Differential Revision: D18439030
Pulled By: albanD
fbshipit-source-id: 4b1f9da70c1c8dfd0a5bc028b6936acd1c64af47
diff --git a/c10/core/ScalarType.h b/c10/core/ScalarType.h
index b9be182..4dfffc8 100644
--- a/c10/core/ScalarType.h
+++ b/c10/core/ScalarType.h
@@ -316,11 +316,11 @@
case ScalarType::name: \
return std::numeric_limits<ctype>::is_signed;
- switch (t) {
- AT_FORALL_SCALAR_TYPES_AND(Half, CASE_SIGNED)
- default:
- AT_ERROR("Unknown ScalarType");
- }
+ switch (toUnderlying(t)) {
+ AT_FORALL_SCALAR_TYPES_AND3(Half, Bool, BFloat16, CASE_SIGNED)
+ default:
+ AT_ERROR("Unknown ScalarType");
+ }
#undef CASE_SIGNED
}
diff --git a/test/test_torch.py b/test/test_torch.py
index e5c35a4..716519e 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -2427,6 +2427,14 @@
x = torch.Tensor([1, nan, 2])
self.assertEqual(torch.isnan(x), torch.ByteTensor([0, 1, 0]))
+ def test_dtype_is_signed(self):
+ for dtype in torch.testing.get_all_dtypes():
+ self.assertEqual(dtype.is_signed, torch.is_signed(torch.tensor(0, dtype=dtype)))
+
+ self.assertFalse(torch.quint8.is_signed)
+ self.assertTrue(torch.qint8.is_signed)
+ self.assertTrue(torch.qint32.is_signed)
+
def test_RNGState(self):
state = torch.get_rng_state()
stateCloned = state.clone()
diff --git a/torch/csrc/Dtype.cpp b/torch/csrc/Dtype.cpp
index c0a0847..a33f783 100644
--- a/torch/csrc/Dtype.cpp
+++ b/torch/csrc/Dtype.cpp
@@ -29,6 +29,15 @@
}
}
+PyObject *THPDtype_is_signed(THPDtype *self, PyObject *noargs)
+{
+ if (at::isSignedType(self->scalar_type)) {
+ Py_RETURN_TRUE;
+ } else {
+ Py_RETURN_FALSE;
+ }
+}
+
PyObject *THPDtype_reduce(THPDtype *self, PyObject *noargs)
{
/*
@@ -42,6 +51,7 @@
static struct PyGetSetDef THPDtype_properties[] = {
{"is_floating_point", (getter)THPDtype_is_floating_point, nullptr, nullptr, nullptr},
+ {"is_signed", (getter)THPDtype_is_signed, nullptr, nullptr, nullptr},
{nullptr}
};