[dynamo][numpy] Add unsigned integer dtypes (#125717)
We should support these to whatever extent we can. They corresponding
`torch.uint<w>` types are defined, so I don't see an issue with
generating the various casting rules and allowing them to trace.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/125717
Approved by: https://github.com/lezcano
diff --git a/test/dynamo_expected_failures/TestHistogram.test_unsigned_monotonicity_check b/test/dynamo_expected_failures/TestHistogram.test_unsigned_monotonicity_check
deleted file mode 100644
index e69de29..0000000
--- a/test/dynamo_expected_failures/TestHistogram.test_unsigned_monotonicity_check
+++ /dev/null
diff --git a/test/dynamo_expected_failures/TestSortComplex.test_sort_real_type_in_H_type_out_F b/test/dynamo_expected_failures/TestSortComplex.test_sort_real_type_in_H_type_out_F
deleted file mode 100644
index e69de29..0000000
--- a/test/dynamo_expected_failures/TestSortComplex.test_sort_real_type_in_H_type_out_F
+++ /dev/null
diff --git a/test/test_numpy_interop.py b/test/test_numpy_interop.py
index 940938c..bf81701 100644
--- a/test/test_numpy_interop.py
+++ b/test/test_numpy_interop.py
@@ -476,13 +476,18 @@
self.assertTrue(r2.requires_grad)
@onlyCPU
- def test_parse_numpy_int(self, device):
+ @skipIfTorchDynamo()
+ def test_parse_numpy_int_overflow(self, device):
+ # assertRaises uses a try-except which dynamo has issues with
# Only concrete class can be given where "Type[number[_64Bit]]" is expected
self.assertRaisesRegex(
RuntimeError,
"(Overflow|an integer is required)",
lambda: torch.mean(torch.randn(1, 1), np.uint64(-1)),
) # type: ignore[call-overload]
+
+ @onlyCPU
+ def test_parse_numpy_int(self, device):
# https://github.com/pytorch/pytorch/issues/29252
for nptype in [np.int16, np.int8, np.uint8, np.int32, np.int64]:
scalar = 3
diff --git a/test/torch_np/numpy_tests/core/test_getlimits.py b/test/torch_np/numpy_tests/core/test_getlimits.py
index ab5b083..3be8bc2 100644
--- a/test/torch_np/numpy_tests/core/test_getlimits.py
+++ b/test/torch_np/numpy_tests/core/test_getlimits.py
@@ -8,14 +8,17 @@
# from numpy.core.getlimits import _discovered_machar, _float_ma
-from unittest import skipIf
+from unittest import expectedFailure as xfail, skipIf
import numpy
from pytest import raises as assert_raises
from torch.testing._internal.common_utils import (
+ instantiate_parametrized_tests,
+ parametrize,
run_tests,
+ subtest,
TEST_WITH_TORCHDYNAMO,
TestCase,
xpassIfTorchDynamo,
@@ -109,6 +112,7 @@
getattr(finfo(dt), attr)
+@instantiate_parametrized_tests
class TestIinfo(TestCase):
def test_basic(self):
dts = list(
@@ -129,11 +133,19 @@
with assert_raises((TypeError, ValueError)):
iinfo("f4")
- def test_unsigned_max(self):
- types = np.sctypes["uint"]
- for T in types:
- max_calculated = T(0) - T(1)
- assert_equal(iinfo(T).max, max_calculated)
+ @parametrize(
+ "T",
+ [
+ np.uint8,
+ # xfail: unsupported add (uint[16,32,64])
+ subtest(np.uint16, decorators=[xfail]),
+ subtest(np.uint32, decorators=[xfail]),
+ subtest(np.uint64, decorators=[xfail]),
+ ],
+ )
+ def test_unsigned_max(self, T):
+ max_calculated = T(0) - T(1)
+ assert_equal(iinfo(T).max, max_calculated)
class TestRepr(TestCase):
diff --git a/test/torch_np/numpy_tests/core/test_scalarmath.py b/test/torch_np/numpy_tests/core/test_scalarmath.py
index 8099ca8..d86595d 100644
--- a/test/torch_np/numpy_tests/core/test_scalarmath.py
+++ b/test/torch_np/numpy_tests/core/test_scalarmath.py
@@ -732,13 +732,16 @@
@instantiate_parametrized_tests
class TestBitShifts(TestCase):
- @parametrize("type_code", np.typecodes["Integer"] + "B")
+ @parametrize("type_code", np.typecodes["AllInteger"])
@parametrize("op", [operator.rshift, operator.lshift])
def test_shift_all_bits(self, type_code, op):
"""Shifts where the shift amount is the width of the type or wider"""
# gh-2449
dt = np.dtype(type_code)
nbits = dt.itemsize * 8
+ if dt in (np.dtype(np.uint64), np.dtype(np.uint32), np.dtype(np.uint16)):
+ raise SkipTest("NYI: bitshift uint64")
+
for val in [5, -5]:
for shift in [nbits, nbits + 4]:
val_scl = np.array(val).astype(dt)[()]
diff --git a/test/torch_np/test_dtype.py b/test/torch_np/test_dtype.py
index 42866ad..e288e54 100644
--- a/test/torch_np/test_dtype.py
+++ b/test/torch_np/test_dtype.py
@@ -18,7 +18,7 @@
dtype_names = [
"bool_",
*[f"int{w}" for w in [8, 16, 32, 64]],
- "uint8",
+ *[f"uint{w}" for w in [8, 16, 32, 64]],
*[f"float{w}" for w in [16, 32, 64]],
*[f"complex{w}" for w in [64, 128]],
]
diff --git a/torch/_numpy/_casting_dicts.py b/torch/_numpy/_casting_dicts.py
index 513e73e..b30ce7c 100644
--- a/torch/_numpy/_casting_dicts.py
+++ b/torch/_numpy/_casting_dicts.py
@@ -3,7 +3,7 @@
import torch
# These two dicts are autogenerated with autogen/gen_dtypes.py,
-# using numpy version 1.23.5.
+# using numpy version 1.24.3.
_can_cast_dict = {
"no": {
@@ -14,6 +14,9 @@
torch.complex64: False,
torch.complex128: False,
torch.uint8: False,
+ torch.uint16: False,
+ torch.uint32: False,
+ torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@@ -27,6 +30,9 @@
torch.complex64: False,
torch.complex128: False,
torch.uint8: False,
+ torch.uint16: False,
+ torch.uint32: False,
+ torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@@ -40,6 +46,9 @@
torch.complex64: False,
torch.complex128: False,
torch.uint8: False,
+ torch.uint16: False,
+ torch.uint32: False,
+ torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@@ -53,6 +62,9 @@
torch.complex64: True,
torch.complex128: False,
torch.uint8: False,
+ torch.uint16: False,
+ torch.uint32: False,
+ torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@@ -66,6 +78,9 @@
torch.complex64: False,
torch.complex128: True,
torch.uint8: False,
+ torch.uint16: False,
+ torch.uint32: False,
+ torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@@ -79,6 +94,57 @@
torch.complex64: False,
torch.complex128: False,
torch.uint8: True,
+ torch.uint16: False,
+ torch.uint32: False,
+ torch.uint64: False,
+ torch.int8: False,
+ torch.int16: False,
+ torch.int32: False,
+ torch.int64: False,
+ torch.bool: False,
+ },
+ torch.uint16: {
+ torch.float16: False,
+ torch.float32: False,
+ torch.float64: False,
+ torch.complex64: False,
+ torch.complex128: False,
+ torch.uint8: False,
+ torch.uint16: True,
+ torch.uint32: False,
+ torch.uint64: False,
+ torch.int8: False,
+ torch.int16: False,
+ torch.int32: False,
+ torch.int64: False,
+ torch.bool: False,
+ },
+ torch.uint32: {
+ torch.float16: False,
+ torch.float32: False,
+ torch.float64: False,
+ torch.complex64: False,
+ torch.complex128: False,
+ torch.uint8: False,
+ torch.uint16: False,
+ torch.uint32: True,
+ torch.uint64: False,
+ torch.int8: False,
+ torch.int16: False,
+ torch.int32: False,
+ torch.int64: False,
+ torch.bool: False,
+ },
+ torch.uint64: {
+ torch.float16: False,
+ torch.float32: False,
+ torch.float64: False,
+ torch.complex64: False,
+ torch.complex128: False,
+ torch.uint8: False,
+ torch.uint16: False,
+ torch.uint32: False,
+ torch.uint64: True,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@@ -92,6 +158,9 @@
torch.complex64: False,
torch.complex128: False,
torch.uint8: False,
+ torch.uint16: False,
+ torch.uint32: False,
+ torch.uint64: False,
torch.int8: True,
torch.int16: False,
torch.int32: False,
@@ -105,6 +174,9 @@
torch.complex64: False,
torch.complex128: False,
torch.uint8: False,
+ torch.uint16: False,
+ torch.uint32: False,
+ torch.uint64: False,
torch.int8: False,
torch.int16: True,
torch.int32: False,
@@ -118,6 +190,9 @@
torch.complex64: False,
torch.complex128: False,
torch.uint8: False,
+ torch.uint16: False,
+ torch.uint32: False,
+ torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: True,
@@ -131,6 +206,9 @@
torch.complex64: False,
torch.complex128: False,
torch.uint8: False,
+ torch.uint16: False,
+ torch.uint32: False,
+ torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@@ -144,6 +222,9 @@
torch.complex64: False,
torch.complex128: False,
torch.uint8: False,
+ torch.uint16: False,
+ torch.uint32: False,
+ torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@@ -159,6 +240,9 @@
torch.complex64: False,
torch.complex128: False,
torch.uint8: False,
+ torch.uint16: False,
+ torch.uint32: False,
+ torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@@ -172,6 +256,9 @@
torch.complex64: False,
torch.complex128: False,
torch.uint8: False,
+ torch.uint16: False,
+ torch.uint32: False,
+ torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@@ -185,6 +272,9 @@
torch.complex64: False,
torch.complex128: False,
torch.uint8: False,
+ torch.uint16: False,
+ torch.uint32: False,
+ torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@@ -198,6 +288,9 @@
torch.complex64: True,
torch.complex128: False,
torch.uint8: False,
+ torch.uint16: False,
+ torch.uint32: False,
+ torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@@ -211,6 +304,9 @@
torch.complex64: False,
torch.complex128: True,
torch.uint8: False,
+ torch.uint16: False,
+ torch.uint32: False,
+ torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@@ -224,6 +320,57 @@
torch.complex64: False,
torch.complex128: False,
torch.uint8: True,
+ torch.uint16: False,
+ torch.uint32: False,
+ torch.uint64: False,
+ torch.int8: False,
+ torch.int16: False,
+ torch.int32: False,
+ torch.int64: False,
+ torch.bool: False,
+ },
+ torch.uint16: {
+ torch.float16: False,
+ torch.float32: False,
+ torch.float64: False,
+ torch.complex64: False,
+ torch.complex128: False,
+ torch.uint8: False,
+ torch.uint16: True,
+ torch.uint32: False,
+ torch.uint64: False,
+ torch.int8: False,
+ torch.int16: False,
+ torch.int32: False,
+ torch.int64: False,
+ torch.bool: False,
+ },
+ torch.uint32: {
+ torch.float16: False,
+ torch.float32: False,
+ torch.float64: False,
+ torch.complex64: False,
+ torch.complex128: False,
+ torch.uint8: False,
+ torch.uint16: False,
+ torch.uint32: True,
+ torch.uint64: False,
+ torch.int8: False,
+ torch.int16: False,
+ torch.int32: False,
+ torch.int64: False,
+ torch.bool: False,
+ },
+ torch.uint64: {
+ torch.float16: False,
+ torch.float32: False,
+ torch.float64: False,
+ torch.complex64: False,
+ torch.complex128: False,
+ torch.uint8: False,
+ torch.uint16: False,
+ torch.uint32: False,
+ torch.uint64: True,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@@ -237,6 +384,9 @@
torch.complex64: False,
torch.complex128: False,
torch.uint8: False,
+ torch.uint16: False,
+ torch.uint32: False,
+ torch.uint64: False,
torch.int8: True,
torch.int16: False,
torch.int32: False,
@@ -250,6 +400,9 @@
torch.complex64: False,
torch.complex128: False,
torch.uint8: False,
+ torch.uint16: False,
+ torch.uint32: False,
+ torch.uint64: False,
torch.int8: False,
torch.int16: True,
torch.int32: False,
@@ -263,6 +416,9 @@
torch.complex64: False,
torch.complex128: False,
torch.uint8: False,
+ torch.uint16: False,
+ torch.uint32: False,
+ torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: True,
@@ -276,6 +432,9 @@
torch.complex64: False,
torch.complex128: False,
torch.uint8: False,
+ torch.uint16: False,
+ torch.uint32: False,
+ torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@@ -289,6 +448,9 @@
torch.complex64: False,
torch.complex128: False,
torch.uint8: False,
+ torch.uint16: False,
+ torch.uint32: False,
+ torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@@ -304,6 +466,9 @@
torch.complex64: True,
torch.complex128: True,
torch.uint8: False,
+ torch.uint16: False,
+ torch.uint32: False,
+ torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@@ -317,6 +482,9 @@
torch.complex64: True,
torch.complex128: True,
torch.uint8: False,
+ torch.uint16: False,
+ torch.uint32: False,
+ torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@@ -330,6 +498,9 @@
torch.complex64: False,
torch.complex128: True,
torch.uint8: False,
+ torch.uint16: False,
+ torch.uint32: False,
+ torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@@ -343,6 +514,9 @@
torch.complex64: True,
torch.complex128: True,
torch.uint8: False,
+ torch.uint16: False,
+ torch.uint32: False,
+ torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@@ -356,6 +530,9 @@
torch.complex64: False,
torch.complex128: True,
torch.uint8: False,
+ torch.uint16: False,
+ torch.uint32: False,
+ torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@@ -369,12 +546,63 @@
torch.complex64: True,
torch.complex128: True,
torch.uint8: True,
+ torch.uint16: True,
+ torch.uint32: True,
+ torch.uint64: True,
torch.int8: False,
torch.int16: True,
torch.int32: True,
torch.int64: True,
torch.bool: False,
},
+ torch.uint16: {
+ torch.float16: False,
+ torch.float32: True,
+ torch.float64: True,
+ torch.complex64: True,
+ torch.complex128: True,
+ torch.uint8: False,
+ torch.uint16: True,
+ torch.uint32: True,
+ torch.uint64: True,
+ torch.int8: False,
+ torch.int16: False,
+ torch.int32: True,
+ torch.int64: True,
+ torch.bool: False,
+ },
+ torch.uint32: {
+ torch.float16: False,
+ torch.float32: False,
+ torch.float64: True,
+ torch.complex64: False,
+ torch.complex128: True,
+ torch.uint8: False,
+ torch.uint16: False,
+ torch.uint32: True,
+ torch.uint64: True,
+ torch.int8: False,
+ torch.int16: False,
+ torch.int32: False,
+ torch.int64: True,
+ torch.bool: False,
+ },
+ torch.uint64: {
+ torch.float16: False,
+ torch.float32: False,
+ torch.float64: True,
+ torch.complex64: False,
+ torch.complex128: True,
+ torch.uint8: False,
+ torch.uint16: False,
+ torch.uint32: False,
+ torch.uint64: True,
+ torch.int8: False,
+ torch.int16: False,
+ torch.int32: False,
+ torch.int64: False,
+ torch.bool: False,
+ },
torch.int8: {
torch.float16: True,
torch.float32: True,
@@ -382,6 +610,9 @@
torch.complex64: True,
torch.complex128: True,
torch.uint8: False,
+ torch.uint16: False,
+ torch.uint32: False,
+ torch.uint64: False,
torch.int8: True,
torch.int16: True,
torch.int32: True,
@@ -395,6 +626,9 @@
torch.complex64: True,
torch.complex128: True,
torch.uint8: False,
+ torch.uint16: False,
+ torch.uint32: False,
+ torch.uint64: False,
torch.int8: False,
torch.int16: True,
torch.int32: True,
@@ -408,6 +642,9 @@
torch.complex64: False,
torch.complex128: True,
torch.uint8: False,
+ torch.uint16: False,
+ torch.uint32: False,
+ torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: True,
@@ -421,6 +658,9 @@
torch.complex64: False,
torch.complex128: True,
torch.uint8: False,
+ torch.uint16: False,
+ torch.uint32: False,
+ torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@@ -434,6 +674,9 @@
torch.complex64: True,
torch.complex128: True,
torch.uint8: True,
+ torch.uint16: True,
+ torch.uint32: True,
+ torch.uint64: True,
torch.int8: True,
torch.int16: True,
torch.int32: True,
@@ -449,6 +692,9 @@
torch.complex64: True,
torch.complex128: True,
torch.uint8: False,
+ torch.uint16: False,
+ torch.uint32: False,
+ torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@@ -462,6 +708,9 @@
torch.complex64: True,
torch.complex128: True,
torch.uint8: False,
+ torch.uint16: False,
+ torch.uint32: False,
+ torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@@ -475,6 +724,9 @@
torch.complex64: True,
torch.complex128: True,
torch.uint8: False,
+ torch.uint16: False,
+ torch.uint32: False,
+ torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@@ -488,6 +740,9 @@
torch.complex64: True,
torch.complex128: True,
torch.uint8: False,
+ torch.uint16: False,
+ torch.uint32: False,
+ torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@@ -501,6 +756,9 @@
torch.complex64: True,
torch.complex128: True,
torch.uint8: False,
+ torch.uint16: False,
+ torch.uint32: False,
+ torch.uint64: False,
torch.int8: False,
torch.int16: False,
torch.int32: False,
@@ -514,6 +772,57 @@
torch.complex64: True,
torch.complex128: True,
torch.uint8: True,
+ torch.uint16: True,
+ torch.uint32: True,
+ torch.uint64: True,
+ torch.int8: True,
+ torch.int16: True,
+ torch.int32: True,
+ torch.int64: True,
+ torch.bool: False,
+ },
+ torch.uint16: {
+ torch.float16: True,
+ torch.float32: True,
+ torch.float64: True,
+ torch.complex64: True,
+ torch.complex128: True,
+ torch.uint8: True,
+ torch.uint16: True,
+ torch.uint32: True,
+ torch.uint64: True,
+ torch.int8: True,
+ torch.int16: True,
+ torch.int32: True,
+ torch.int64: True,
+ torch.bool: False,
+ },
+ torch.uint32: {
+ torch.float16: True,
+ torch.float32: True,
+ torch.float64: True,
+ torch.complex64: True,
+ torch.complex128: True,
+ torch.uint8: True,
+ torch.uint16: True,
+ torch.uint32: True,
+ torch.uint64: True,
+ torch.int8: True,
+ torch.int16: True,
+ torch.int32: True,
+ torch.int64: True,
+ torch.bool: False,
+ },
+ torch.uint64: {
+ torch.float16: True,
+ torch.float32: True,
+ torch.float64: True,
+ torch.complex64: True,
+ torch.complex128: True,
+ torch.uint8: True,
+ torch.uint16: True,
+ torch.uint32: True,
+ torch.uint64: True,
torch.int8: True,
torch.int16: True,
torch.int32: True,
@@ -527,6 +836,9 @@
torch.complex64: True,
torch.complex128: True,
torch.uint8: False,
+ torch.uint16: False,
+ torch.uint32: False,
+ torch.uint64: False,
torch.int8: True,
torch.int16: True,
torch.int32: True,
@@ -540,6 +852,9 @@
torch.complex64: True,
torch.complex128: True,
torch.uint8: False,
+ torch.uint16: False,
+ torch.uint32: False,
+ torch.uint64: False,
torch.int8: True,
torch.int16: True,
torch.int32: True,
@@ -553,6 +868,9 @@
torch.complex64: True,
torch.complex128: True,
torch.uint8: False,
+ torch.uint16: False,
+ torch.uint32: False,
+ torch.uint64: False,
torch.int8: True,
torch.int16: True,
torch.int32: True,
@@ -566,6 +884,9 @@
torch.complex64: True,
torch.complex128: True,
torch.uint8: False,
+ torch.uint16: False,
+ torch.uint32: False,
+ torch.uint64: False,
torch.int8: True,
torch.int16: True,
torch.int32: True,
@@ -579,6 +900,9 @@
torch.complex64: True,
torch.complex128: True,
torch.uint8: True,
+ torch.uint16: True,
+ torch.uint32: True,
+ torch.uint64: True,
torch.int8: True,
torch.int16: True,
torch.int32: True,
@@ -594,6 +918,9 @@
torch.complex64: True,
torch.complex128: True,
torch.uint8: True,
+ torch.uint16: True,
+ torch.uint32: True,
+ torch.uint64: True,
torch.int8: True,
torch.int16: True,
torch.int32: True,
@@ -607,6 +934,9 @@
torch.complex64: True,
torch.complex128: True,
torch.uint8: True,
+ torch.uint16: True,
+ torch.uint32: True,
+ torch.uint64: True,
torch.int8: True,
torch.int16: True,
torch.int32: True,
@@ -620,6 +950,9 @@
torch.complex64: True,
torch.complex128: True,
torch.uint8: True,
+ torch.uint16: True,
+ torch.uint32: True,
+ torch.uint64: True,
torch.int8: True,
torch.int16: True,
torch.int32: True,
@@ -633,6 +966,9 @@
torch.complex64: True,
torch.complex128: True,
torch.uint8: True,
+ torch.uint16: True,
+ torch.uint32: True,
+ torch.uint64: True,
torch.int8: True,
torch.int16: True,
torch.int32: True,
@@ -646,6 +982,9 @@
torch.complex64: True,
torch.complex128: True,
torch.uint8: True,
+ torch.uint16: True,
+ torch.uint32: True,
+ torch.uint64: True,
torch.int8: True,
torch.int16: True,
torch.int32: True,
@@ -659,6 +998,57 @@
torch.complex64: True,
torch.complex128: True,
torch.uint8: True,
+ torch.uint16: True,
+ torch.uint32: True,
+ torch.uint64: True,
+ torch.int8: True,
+ torch.int16: True,
+ torch.int32: True,
+ torch.int64: True,
+ torch.bool: True,
+ },
+ torch.uint16: {
+ torch.float16: True,
+ torch.float32: True,
+ torch.float64: True,
+ torch.complex64: True,
+ torch.complex128: True,
+ torch.uint8: True,
+ torch.uint16: True,
+ torch.uint32: True,
+ torch.uint64: True,
+ torch.int8: True,
+ torch.int16: True,
+ torch.int32: True,
+ torch.int64: True,
+ torch.bool: True,
+ },
+ torch.uint32: {
+ torch.float16: True,
+ torch.float32: True,
+ torch.float64: True,
+ torch.complex64: True,
+ torch.complex128: True,
+ torch.uint8: True,
+ torch.uint16: True,
+ torch.uint32: True,
+ torch.uint64: True,
+ torch.int8: True,
+ torch.int16: True,
+ torch.int32: True,
+ torch.int64: True,
+ torch.bool: True,
+ },
+ torch.uint64: {
+ torch.float16: True,
+ torch.float32: True,
+ torch.float64: True,
+ torch.complex64: True,
+ torch.complex128: True,
+ torch.uint8: True,
+ torch.uint16: True,
+ torch.uint32: True,
+ torch.uint64: True,
torch.int8: True,
torch.int16: True,
torch.int32: True,
@@ -672,6 +1062,9 @@
torch.complex64: True,
torch.complex128: True,
torch.uint8: True,
+ torch.uint16: True,
+ torch.uint32: True,
+ torch.uint64: True,
torch.int8: True,
torch.int16: True,
torch.int32: True,
@@ -685,6 +1078,9 @@
torch.complex64: True,
torch.complex128: True,
torch.uint8: True,
+ torch.uint16: True,
+ torch.uint32: True,
+ torch.uint64: True,
torch.int8: True,
torch.int16: True,
torch.int32: True,
@@ -698,6 +1094,9 @@
torch.complex64: True,
torch.complex128: True,
torch.uint8: True,
+ torch.uint16: True,
+ torch.uint32: True,
+ torch.uint64: True,
torch.int8: True,
torch.int16: True,
torch.int32: True,
@@ -711,6 +1110,9 @@
torch.complex64: True,
torch.complex128: True,
torch.uint8: True,
+ torch.uint16: True,
+ torch.uint32: True,
+ torch.uint64: True,
torch.int8: True,
torch.int16: True,
torch.int32: True,
@@ -724,6 +1126,9 @@
torch.complex64: True,
torch.complex128: True,
torch.uint8: True,
+ torch.uint16: True,
+ torch.uint32: True,
+ torch.uint64: True,
torch.int8: True,
torch.int16: True,
torch.int32: True,
@@ -742,6 +1147,9 @@
torch.complex64: torch.complex64,
torch.complex128: torch.complex128,
torch.uint8: torch.float16,
+ torch.uint16: torch.float32,
+ torch.uint32: torch.float64,
+ torch.uint64: torch.float64,
torch.int8: torch.float16,
torch.int16: torch.float32,
torch.int32: torch.float64,
@@ -755,6 +1163,9 @@
torch.complex64: torch.complex64,
torch.complex128: torch.complex128,
torch.uint8: torch.float32,
+ torch.uint16: torch.float32,
+ torch.uint32: torch.float64,
+ torch.uint64: torch.float64,
torch.int8: torch.float32,
torch.int16: torch.float32,
torch.int32: torch.float64,
@@ -768,6 +1179,9 @@
torch.complex64: torch.complex128,
torch.complex128: torch.complex128,
torch.uint8: torch.float64,
+ torch.uint16: torch.float64,
+ torch.uint32: torch.float64,
+ torch.uint64: torch.float64,
torch.int8: torch.float64,
torch.int16: torch.float64,
torch.int32: torch.float64,
@@ -781,6 +1195,9 @@
torch.complex64: torch.complex64,
torch.complex128: torch.complex128,
torch.uint8: torch.complex64,
+ torch.uint16: torch.complex64,
+ torch.uint32: torch.complex128,
+ torch.uint64: torch.complex128,
torch.int8: torch.complex64,
torch.int16: torch.complex64,
torch.int32: torch.complex128,
@@ -794,6 +1211,9 @@
torch.complex64: torch.complex128,
torch.complex128: torch.complex128,
torch.uint8: torch.complex128,
+ torch.uint16: torch.complex128,
+ torch.uint32: torch.complex128,
+ torch.uint64: torch.complex128,
torch.int8: torch.complex128,
torch.int16: torch.complex128,
torch.int32: torch.complex128,
@@ -807,12 +1227,63 @@
torch.complex64: torch.complex64,
torch.complex128: torch.complex128,
torch.uint8: torch.uint8,
+ torch.uint16: torch.uint16,
+ torch.uint32: torch.uint32,
+ torch.uint64: torch.uint64,
torch.int8: torch.int16,
torch.int16: torch.int16,
torch.int32: torch.int32,
torch.int64: torch.int64,
torch.bool: torch.uint8,
},
+ torch.uint16: {
+ torch.float16: torch.float32,
+ torch.float32: torch.float32,
+ torch.float64: torch.float64,
+ torch.complex64: torch.complex64,
+ torch.complex128: torch.complex128,
+ torch.uint8: torch.uint16,
+ torch.uint16: torch.uint16,
+ torch.uint32: torch.uint32,
+ torch.uint64: torch.uint64,
+ torch.int8: torch.int32,
+ torch.int16: torch.int32,
+ torch.int32: torch.int32,
+ torch.int64: torch.int64,
+ torch.bool: torch.uint16,
+ },
+ torch.uint32: {
+ torch.float16: torch.float64,
+ torch.float32: torch.float64,
+ torch.float64: torch.float64,
+ torch.complex64: torch.complex128,
+ torch.complex128: torch.complex128,
+ torch.uint8: torch.uint32,
+ torch.uint16: torch.uint32,
+ torch.uint32: torch.uint32,
+ torch.uint64: torch.uint64,
+ torch.int8: torch.int64,
+ torch.int16: torch.int64,
+ torch.int32: torch.int64,
+ torch.int64: torch.int64,
+ torch.bool: torch.uint32,
+ },
+ torch.uint64: {
+ torch.float16: torch.float64,
+ torch.float32: torch.float64,
+ torch.float64: torch.float64,
+ torch.complex64: torch.complex128,
+ torch.complex128: torch.complex128,
+ torch.uint8: torch.uint64,
+ torch.uint16: torch.uint64,
+ torch.uint32: torch.uint64,
+ torch.uint64: torch.uint64,
+ torch.int8: torch.float64,
+ torch.int16: torch.float64,
+ torch.int32: torch.float64,
+ torch.int64: torch.float64,
+ torch.bool: torch.uint64,
+ },
torch.int8: {
torch.float16: torch.float16,
torch.float32: torch.float32,
@@ -820,6 +1291,9 @@
torch.complex64: torch.complex64,
torch.complex128: torch.complex128,
torch.uint8: torch.int16,
+ torch.uint16: torch.int32,
+ torch.uint32: torch.int64,
+ torch.uint64: torch.float64,
torch.int8: torch.int8,
torch.int16: torch.int16,
torch.int32: torch.int32,
@@ -833,6 +1307,9 @@
torch.complex64: torch.complex64,
torch.complex128: torch.complex128,
torch.uint8: torch.int16,
+ torch.uint16: torch.int32,
+ torch.uint32: torch.int64,
+ torch.uint64: torch.float64,
torch.int8: torch.int16,
torch.int16: torch.int16,
torch.int32: torch.int32,
@@ -846,6 +1323,9 @@
torch.complex64: torch.complex128,
torch.complex128: torch.complex128,
torch.uint8: torch.int32,
+ torch.uint16: torch.int32,
+ torch.uint32: torch.int64,
+ torch.uint64: torch.float64,
torch.int8: torch.int32,
torch.int16: torch.int32,
torch.int32: torch.int32,
@@ -859,6 +1339,9 @@
torch.complex64: torch.complex128,
torch.complex128: torch.complex128,
torch.uint8: torch.int64,
+ torch.uint16: torch.int64,
+ torch.uint32: torch.int64,
+ torch.uint64: torch.float64,
torch.int8: torch.int64,
torch.int16: torch.int64,
torch.int32: torch.int64,
@@ -872,6 +1355,9 @@
torch.complex64: torch.complex64,
torch.complex128: torch.complex128,
torch.uint8: torch.uint8,
+ torch.uint16: torch.uint16,
+ torch.uint32: torch.uint32,
+ torch.uint64: torch.uint64,
torch.int8: torch.int8,
torch.int16: torch.int16,
torch.int32: torch.int32,
diff --git a/torch/_numpy/_dtypes.py b/torch/_numpy/_dtypes.py
index f8b8f4f..27799ad 100644
--- a/torch/_numpy/_dtypes.py
+++ b/torch/_numpy/_dtypes.py
@@ -113,6 +113,24 @@
torch_dtype = torch.uint8
+class uint16(unsignedinteger):
+ name = "uint16"
+ typecode = "H"
+ torch_dtype = torch.uint16
+
+
+class uint32(signedinteger):
+ name = "uint32"
+ typecode = "I"
+ torch_dtype = torch.uint32
+
+
+class uint64(signedinteger):
+ name = "uint64"
+ typecode = "L"
+ torch_dtype = torch.uint64
+
+
# floating point
@@ -160,6 +178,7 @@
"byte": int8,
"short": int16,
"longlong": int64, # XXX: is this correct?
+ "ulonglong": uint64,
"ubyte": uint8,
"half": float16,
"single": float32,
@@ -180,7 +199,7 @@
# cf tests/core/test_scalar_methods.py
sctypes = {
"int": [int8, int16, int32, int64],
- "uint": [uint8],
+ "uint": [uint8, uint16, uint32, uint64],
"float": [float16, float32, float64],
"complex": [complex64, complex128],
"others": [bool_],