Improve uint{16,32,64} dlpack/numpy compatibility (#116808)
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/116808
Approved by: https://github.com/malfet, https://github.com/albanD
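A minimal sketch of what this change enables, assuming a build that contains it: uint16/uint32/uint64 arrays now cross the numpy and DLPack boundaries instead of raising TypeError.

    import numpy as np
    import torch

    a = np.arange(4, dtype=np.uint16)
    t = torch.from_numpy(a)   # previously: TypeError (unsupported type)
    assert t.dtype == torch.uint16
    u = torch.from_dlpack(a)  # DLPack import path (needs numpy >= 1.22)
    assert u.dtype == torch.uint16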
diff --git a/aten/src/ATen/DLConvertor.cpp b/aten/src/ATen/DLConvertor.cpp
index 8f2cac8..8a74e1c 100644
--- a/aten/src/ATen/DLConvertor.cpp
+++ b/aten/src/ATen/DLConvertor.cpp
@@ -148,6 +148,15 @@
case 8:
stype = ScalarType::Byte;
break;
+ case 16:
+ stype = ScalarType::UInt16;
+ break;
+ case 32:
+ stype = ScalarType::UInt32;
+ break;
+ case 64:
+ stype = ScalarType::UInt64;
+ break;
default:
TORCH_CHECK(
false, "Unsupported kUInt bits " + c10::to_string(dtype.bits));
diff --git a/test/test_dlpack.py b/test/test_dlpack.py
index 35fe8ad..87a4657 100644
--- a/test/test_dlpack.py
+++ b/test/test_dlpack.py
@@ -15,7 +15,7 @@
@skipMeta
@onlyNativeDeviceTypes
- @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
+ @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool, torch.uint16, torch.uint32, torch.uint64))
def test_dlpack_capsule_conversion(self, device, dtype):
x = make_tensor((5,), dtype=dtype, device=device)
z = from_dlpack(to_dlpack(x))
@@ -23,7 +23,7 @@
@skipMeta
@onlyNativeDeviceTypes
- @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
+ @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool, torch.uint16, torch.uint32, torch.uint64))
def test_dlpack_protocol_conversion(self, device, dtype):
x = make_tensor((5,), dtype=dtype, device=device)
z = from_dlpack(x)
@@ -62,7 +62,7 @@
@skipMeta
@onlyNativeDeviceTypes
- @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
+ @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool, torch.uint16, torch.uint32, torch.uint64))
def test_from_dlpack(self, device, dtype):
x = make_tensor((5,), dtype=dtype, device=device)
y = torch.from_dlpack(x)
@@ -70,7 +70,7 @@
@skipMeta
@onlyNativeDeviceTypes
- @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
+ @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool, torch.uint16, torch.uint32, torch.uint64))
def test_from_dlpack_noncontinguous(self, device, dtype):
x = make_tensor((25,), dtype=dtype, device=device).reshape(5, 5)
@@ -113,7 +113,7 @@
@skipMeta
@onlyNativeDeviceTypes
- @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
+ @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool, torch.uint16, torch.uint32, torch.uint64))
def test_from_dlpack_dtype(self, device, dtype):
x = make_tensor((5,), dtype=dtype, device=device)
y = torch.from_dlpack(x)
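These tests simply extend the existing dtype lists. Since DLPack carries strides, the noncontiguous case should also hold for the new dtypes; a hedged sketch:

    import torch

    x = torch.zeros(5, 5, dtype=torch.uint16)
    y = torch.from_dlpack(x.t())  # a strided view survives the exchange
    assert y.dtype == torch.uint16 and y.stride() == x.t().stride()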
diff --git a/test/test_numpy_interop.py b/test/test_numpy_interop.py
index a0b9561..ced9192 100644
--- a/test/test_numpy_interop.py
+++ b/test/test_numpy_interop.py
@@ -220,7 +220,7 @@
self.assertEqual(tensor_from_array2[i], array2[i])
# Test unsupported type
- array = np.array([1, 2, 3, 4], dtype=np.uint16)
+ array = np.array(['foo', 'bar'], dtype=np.dtype(np.str_))
with self.assertRaises(TypeError):
tensor_from_array = torch.from_numpy(array)
@@ -417,7 +417,7 @@
@onlyCPU
def test_parse_numpy_int(self, device):
# Only concrete class can be given where "Type[number[_64Bit]]" is expected
- self.assertRaisesRegex(RuntimeError, "Overflow",
+ self.assertRaisesRegex(RuntimeError, "(Overflow|an integer is required)",
lambda: torch.mean(torch.randn(1, 1), np.uint64(-1))) # type: ignore[call-overload]
# https://github.com/pytorch/pytorch/issues/29252
for nptype in [np.int16, np.int8, np.uint8, np.int32, np.int64]:
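The first hunk swaps the "unsupported type" example, since np.uint16 is now convertible; string arrays remain unsupported. A sketch of both behaviors on a build with this change:

    import numpy as np
    import torch

    torch.from_numpy(np.array([1, 2, 3, 4], dtype=np.uint16))  # now succeeds
    try:
        torch.from_numpy(np.array(["foo", "bar"], dtype=np.str_))
    except TypeError as e:
        print(e)  # string ndarrays are still rejected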
diff --git a/test/test_reductions.py b/test/test_reductions.py
index 45bff79..4295630 100644
--- a/test/test_reductions.py
+++ b/test/test_reductions.py
@@ -1666,41 +1666,43 @@
def is_integral(dtype):
return dtype in integral_types()
+ exact_dtype = True
# On Windows CI, the current version of `numpy` promotes all lower integers
# dtypes to int32 while `torch` promotes them to int64. Hence we skip on checking
# the exact dtype.
# Reference : https://dr.pytorch.org/api/view-log-full?build_id=122051580
# PR : https://github.com/pytorch/pytorch/pull/38628#issuecomment-655905370
- exact_dtype = False if (IS_WINDOWS and is_integral(dtype)) else True
-
+ if IS_WINDOWS and is_integral(dtype):
+ exact_dtype = False
+ # For uint8, numpy promotes to uint64 while torch promotes to int64.
+ # So we must skip the exact-dtype check here as well.
if dtype == torch.uint8:
- with self.assertRaises(TypeError):
- self._test_reduction_function_with_numpy(torch_fn, np_fn, device, dtype, with_extremal=with_extremal)
+ exact_dtype = False
+
+ # TODO: Investigate why the output is not close to numpy.
+ if dtype == torch.float16:
+ atol = 0.4
+ rtol = 1e-2
+ elif dtype == torch.float32:
+ atol = 7e-05
+ rtol = 3e-06
else:
- # TODO: Investigate why the output is not close to numpy.
- if dtype == torch.float16:
- atol = 0.4
- rtol = 1e-2
- elif dtype == torch.float32:
- atol = 7e-05
- rtol = 3e-06
- else:
- # Default values
- atol = None
- rtol = None
- self._test_reduction_function_with_numpy(torch_fn, np_fn, device, dtype,
- atol=atol, rtol=rtol, exact_dtype=exact_dtype,
- with_keepdim=with_keepdim, with_extremal=with_extremal)
+ # Default values
+ atol = None
+ rtol = None
+ self._test_reduction_function_with_numpy(torch_fn, np_fn, device, dtype,
+ atol=atol, rtol=rtol, exact_dtype=exact_dtype,
+ with_keepdim=with_keepdim, with_extremal=with_extremal)
@onlyNativeDeviceTypes
- @dtypes(*all_types_and(torch.half))
+ @dtypes(*set(all_types_and(torch.half)) - {torch.uint8})
def test_sum_vs_numpy(self, device, dtype):
self._test_sum_reduction_vs_numpy(torch.sum, np.sum, device, dtype)
self._test_sum_reduction_vs_numpy(torch.sum, np.sum, device, dtype, with_extremal=True)
self._test_sum_reduction_vs_numpy(torch.sum, np.sum, device, dtype, with_keepdim=True)
@onlyNativeDeviceTypes
- @dtypes(*all_types_and(torch.half))
+ @dtypes(*set(all_types_and(torch.half)) - {torch.uint8})
def test_nansum_vs_numpy(self, device, dtype):
self._test_sum_reduction_vs_numpy(torch.nansum, np.nansum, device, dtype)
self._test_sum_reduction_vs_numpy(torch.nansum, np.nansum, device, dtype, with_extremal=True)
@@ -3528,6 +3530,10 @@
# Workaround https://github.com/pytorch/pytorch/issues/66556
expected = np.asarray(expected) # transform numpy scalars to numpy.ndarray instances
+ # NumPy differs here, producing uint32 on Windows (uint64 elsewhere)
+ if expected.dtype in [np.uint64, np.uint32]:
+ exact_dtype = False
+
msg = ("Failed to produce expected results! Input tensor was"
f" {t}, torch result is {actual}, and reference result is"
f" {expected}.") if t.numel() < 10 else None
diff --git a/test/test_tensor_creation_ops.py b/test/test_tensor_creation_ops.py
index bfc7e51..1a3adb4 100644
--- a/test/test_tensor_creation_ops.py
+++ b/test/test_tensor_creation_ops.py
@@ -543,10 +543,11 @@
self.assertEqual(y, expected_y)
self.assertEqual(x, expected_x)
- @dtypes(*all_types_and_complex())
+ @dtypes(*all_types_and_complex(), torch.uint16, torch.uint32, torch.uint64)
def test_cat_out_fast_path_dim0_dim1(self, device, dtype):
+ int_types = integral_types_and(torch.uint16, torch.uint32, torch.uint64)
x = torch.zeros((0), device=device, dtype=dtype)
- if dtype in integral_types():
+ if dtype in int_types:
y = torch.randint(low=0, high=100, size=(4, 6), device=device, dtype=dtype)
else:
y = torch.randn((4, 6), device=device, dtype=dtype)
@@ -575,7 +576,7 @@
self.assertEqual(b_fastcat, expected_b)
# Finally, we need to make sure backward is not broken
# Integral types will not have grad
- if dtype not in integral_types():
+ if dtype not in int_types:
a = torch.randn((4, 3), device=device, dtype=dtype, requires_grad=True)
b = torch.randn((2, 3), device=device, dtype=dtype, requires_grad=True)
c = torch.randn((5, 3), device=device, dtype=dtype, requires_grad=True)
@@ -1046,6 +1047,8 @@
self._float_to_int_conversion_helper(vals, device, dtype, refs)
# Note: CUDA will fail this test on most dtypes, often dramatically.
+ # NB: torch.uint16, torch.uint32, torch.uint64 are excluded, as this test
+ # nondeterministically fails with "RuntimeWarning: invalid value encountered in cast"
@onlyCPU
@dtypes(torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
def test_float_to_int_conversion_nonfinite(self, device, dtype):
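The NB comment above refers to numpy's behavior when casting non-finite floats to unsigned integer types, which is undefined and platform-dependent; a sketch of the warning it mentions:

    import numpy as np

    # Emits "RuntimeWarning: invalid value encountered in cast" on recent numpy.
    np.array([np.inf, np.nan], dtype=np.float32).astype(np.uint16)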
diff --git a/test/test_type_promotion.py b/test/test_type_promotion.py
index ceef13b..dd96bd2 100644
--- a/test/test_type_promotion.py
+++ b/test/test_type_promotion.py
@@ -941,8 +941,10 @@
@unittest.skipIf(not TEST_NUMPY, "NumPy not found")
@float_double_default_dtype
@onlyCPU
- @dtypes(*list(itertools.product(set(numpy_to_torch_dtype_dict.values()),
- set(numpy_to_torch_dtype_dict.values()))))
+ # NB: skip uint16/uint32/uint64, as PyTorch doesn't implement promotion for them
+ @dtypes(*list(itertools.product(
+ set(numpy_to_torch_dtype_dict.values()) - {torch.uint16, torch.uint32, torch.uint64},
+ set(numpy_to_torch_dtype_dict.values()) - {torch.uint16, torch.uint32, torch.uint64})))
def test_numpy_array_binary_ufunc_promotion(self, device, dtypes):
import operator
np_type = torch_to_numpy_dtype_dict[dtypes[0]]
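The skip reflects that, as of this commit, type promotion involving the new uint dtypes is not implemented. A sketch (the exact error type and message are an assumption):

    import torch

    try:
        torch.promote_types(torch.uint16, torch.int32)
    except RuntimeError as e:
        print(e)  # promotion for uint16/uint32/uint64 is unsupported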
diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py
index b684107..d32cf9c 100644
--- a/torch/_dynamo/utils.py
+++ b/torch/_dynamo/utils.py
@@ -423,6 +423,9 @@
return True
except TypeError:
return False
+ # cannot hash writable memoryview object
+ except ValueError:
+ return False
def nothing(*args, **kwargs):
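The extra except clause is needed because hashing a writable memoryview raises ValueError rather than TypeError, so a TypeError-only probe would let the exception escape:

    try:
        hash(memoryview(bytearray(b"abc")))
    except ValueError as e:
        print(e)  # cannot hash writable memoryview object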
diff --git a/torch/csrc/utils/tensor_numpy.cpp b/torch/csrc/utils/tensor_numpy.cpp
index 95147c7..ed36d96 100644
--- a/torch/csrc/utils/tensor_numpy.cpp
+++ b/torch/csrc/utils/tensor_numpy.cpp
@@ -295,6 +295,12 @@
return NPY_INT8;
case kByte:
return NPY_UINT8;
+ case kUInt16:
+ return NPY_UINT16;
+ case kUInt32:
+ return NPY_UINT32;
+ case kUInt64:
+ return NPY_UINT64;
case kBool:
return NPY_BOOL;
default:
@@ -320,6 +326,12 @@
return kChar;
case NPY_UINT8:
return kByte;
+ case NPY_UINT16:
+ return kUInt16;
+ case NPY_UINT32:
+ return kUInt32;
+ case NPY_UINT64:
+ return kUInt64;
case NPY_BOOL:
return kBool;
default:
@@ -346,7 +358,7 @@
throw python_error();
throw TypeError(
"can't convert np.ndarray of type %s. The only supported types are: "
- "float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool.",
+ "float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint64, uint32, uint16, uint8, and bool.",
((PyTypeObject*)pytype.get())->tp_name);
}
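With both mapping directions in place, the round trip should hold; a sketch on a build with this change:

    import numpy as np
    import torch

    t = torch.from_numpy(np.zeros(3, dtype=np.uint32))  # NPY_UINT32 -> kUInt32
    assert t.dtype == torch.uint32
    a = t.numpy()                                       # kUInt32 -> NPY_UINT32
    assert a.dtype == np.uint32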
diff --git a/torch/testing/_comparison.py b/torch/testing/_comparison.py
index 1474501..3d405fe 100644
--- a/torch/testing/_comparison.py
+++ b/torch/testing/_comparison.py
@@ -800,7 +800,16 @@
expected = expected.cpu()
if actual.dtype != expected.dtype:
- dtype = torch.promote_types(actual.dtype, expected.dtype)
+ actual_dtype = actual.dtype
+ expected_dtype = expected.dtype
+ # For uint64 this is not sound in general, which is why promote_types
+ # disallows it; but for testing purposes we are unlikely to be confused
+ # by a large uint64 overflowing into a negative int64.
+ if actual_dtype in [torch.uint64, torch.uint32, torch.uint16]:
+ actual_dtype = torch.int64
+ if expected_dtype in [torch.uint64, torch.uint32, torch.uint16]:
+ expected_dtype = torch.int64
+ dtype = torch.promote_types(actual_dtype, expected_dtype)
actual = actual.to(dtype)
expected = expected.to(dtype)
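The workaround in isolation, as a hypothetical helper (promote_for_testing is not part of any API): widen the new unsigned dtypes to int64 before calling promote_types.

    import torch

    def promote_for_testing(a, b):
        wide_uints = {torch.uint16, torch.uint32, torch.uint64}
        a = torch.int64 if a in wide_uints else a
        b = torch.int64 if b in wide_uints else b
        return torch.promote_types(a, b)

    assert promote_for_testing(torch.uint64, torch.int32) == torch.int64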
diff --git a/torch/testing/_internal/common_dtype.py b/torch/testing/_internal/common_dtype.py
index 8d7d2bf..fca14ba 100644
--- a/torch/testing/_internal/common_dtype.py
+++ b/torch/testing/_internal/common_dtype.py
@@ -44,6 +44,7 @@
def double_types():
return _double_types
+# NB: Does not contain uint16/uint32/uint64 for BC reasons
_integral_types = _dispatch_dtypes((torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64))
def integral_types():
return _integral_types
diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index 6f7ea70..e912978 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -1491,6 +1491,9 @@
numpy_to_torch_dtype_dict = {
np.bool_ : torch.bool,
np.uint8 : torch.uint8,
+ np.uint16 : torch.uint16,
+ np.uint32 : torch.uint32,
+ np.uint64 : torch.uint64,
np.int8 : torch.int8,
np.int16 : torch.int16,
np.int32 : torch.int32,
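Usage sketch: the test infrastructure resolves torch dtypes from numpy scalar types through this dict, so the new entries make the uint types visible to dtype-parametrized tests.

    import numpy as np
    import torch
    from torch.testing._internal.common_utils import numpy_to_torch_dtype_dict

    assert numpy_to_torch_dtype_dict[np.uint16] == torch.uint16
    assert numpy_to_torch_dtype_dict[np.uint64] == torch.uint64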
diff --git a/torch/testing/_internal/dynamo_test_failures.py b/torch/testing/_internal/dynamo_test_failures.py
index 13de3df..c369bdd 100644
--- a/torch/testing/_internal/dynamo_test_failures.py
+++ b/torch/testing/_internal/dynamo_test_failures.py
@@ -598,7 +598,6 @@
"TestShuffle.test_1d_use_numpy_False",
"TestShuffle.test_2d_use_numpy_True",
"TestShuffle.test_2d_use_numpy_False",
- "TestArrayCreationCopyArgument.test_buffer_interface",
"TestWritebackIfCopy.test_take_mode_raise",
"TestArange.test_infinite",
"TestArrayConstruction.test_array_empty",
@@ -609,7 +608,6 @@
"TestFromBuffer.test_basic_little_dtype2",
"TestArrayCreationCopyArgument.test_striding_not_ok",
"TestArange.test_require_range",
- "TestStats.test_dtype_from_input",
"TestArange.test_nan_step",
"TestWritebackIfCopy.test_argmin_with_out",
"TestArrayAttributeDeletion.test_multiarray_not_writable_attributes_deletion",
@@ -2402,122 +2400,23 @@
"TestTensorBoardSummary.test_hparams_string", # test_tensorboard
"TestTensorBoardSummary.test_hparams_bool", # test_tensorboard
"TestTensorBoardSummary.test_uint8_image", # test_tensorboard
- "TestBufferProtocolCPU.test_shared_buffer_cpu_uint8", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_with_count_and_offset_cpu_complex128", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_same_type_cpu_complex128", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_invalid_positional_args_cpu_float16", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_requires_grad_cpu_int32", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_invalid_positional_args_cpu_uint8", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_with_count_cpu_complex128", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_requires_grad_cpu_uint8", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_with_count_cpu_uint8", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_with_count_cpu_float64", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_shared_buffer_cpu_complex128", # test_tensor_creation_ops
"TestAsArrayCPU.test_copy_list_cpu_float64", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_with_count_cpu_float32", # test_tensor_creation_ops
"TestAsArrayCPU.test_copy_list_cpu_int64", # test_tensor_creation_ops
- "TestAsArrayCPU.test_alias_from_buffer_cpu_float16", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_with_offset_cpu_int64", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_same_type_cpu_int16", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_with_offset_cpu_float64", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_invalid_positional_args_cpu_float64", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_same_type_cpu_float32", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_with_count_and_offset_cpu_uint8", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_with_count_and_offset_cpu_bool", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_same_type_cpu_int32", # test_tensor_creation_ops
- "TestAsArrayCPU.test_alias_from_buffer_cpu_int16", # test_tensor_creation_ops
- "TestAsArrayCPU.test_alias_from_buffer_cpu_float32", # test_tensor_creation_ops
"TestAsArrayCPU.test_copy_list_cpu_int32", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_requires_grad_cpu_int8", # test_tensor_creation_ops
- "TestAsArrayCPU.test_copy_from_buffer_cpu_int16", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_with_count_and_offset_cpu_int8", # test_tensor_creation_ops
- "TestAsArrayCPU.test_copy_from_buffer_cpu_int8", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_requires_grad_cpu_float16", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_invalid_positional_args_cpu_int64", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_same_type_cpu_int64", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_with_count_and_offset_cpu_float16", # test_tensor_creation_ops
- "TestAsArrayCPU.test_copy_from_buffer_cpu_complex64", # test_tensor_creation_ops
- "TestAsArrayCPU.test_alias_from_buffer_cpu_uint8", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_with_count_and_offset_cpu_int64", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_invalid_positional_args_cpu_complex128", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_with_count_and_offset_cpu_complex64", # test_tensor_creation_ops
- "TestAsArrayCPU.test_copy_from_buffer_cpu_uint8", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_invalid_positional_args_cpu_bool", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_with_count_cpu_int32", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_invalid_positional_args_cpu_int16", # test_tensor_creation_ops
- "TestAsArrayCPU.test_copy_from_buffer_cpu_complex128", # test_tensor_creation_ops
- "TestAsArrayCPU.test_copy_from_buffer_cpu_float32", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_with_offset_cpu_int16", # test_tensor_creation_ops
"TestAsArrayCPU.test_copy_list_cpu_float32", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_with_count_and_offset_cpu_float64", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_requires_grad_cpu_float32", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_same_type_cpu_uint8", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_with_count_and_offset_cpu_int16", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_shared_buffer_cpu_int8", # test_tensor_creation_ops
- "TestAsArrayCPU.test_copy_from_buffer_cpu_float64", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_with_count_cpu_int8", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_invalid_positional_args_cpu_int8", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_with_count_and_offset_cpu_float32", # test_tensor_creation_ops
- "TestAsArrayCPU.test_alias_from_buffer_cpu_complex64", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_same_type_cpu_float16", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_same_type_cpu_bool", # test_tensor_creation_ops
"TestTensorCreationCPU.test_tensor_factory_type_inference_cpu", # test_tensor_creation_ops
- "TestAsArrayCPU.test_copy_from_buffer_cpu_float16", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_with_count_cpu_bool", # test_tensor_creation_ops
"TestBufferProtocolCPU.test_byte_to_int_cpu", # test_tensor_creation_ops
"TestTensorCreationCPU.test_block_diag_cpu", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_with_offset_cpu_complex64", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_with_offset_cpu_uint8", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_with_count_cpu_int64", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_requires_grad_cpu_int16", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_requires_grad_cpu_complex128", # test_tensor_creation_ops
"TestAsArrayCPU.test_copy_list_cpu_int8", # test_tensor_creation_ops
- "TestAsArrayCPU.test_alias_from_buffer_cpu_complex128", # test_tensor_creation_ops
"TestAsArrayCPU.test_copy_list_cpu_float16", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_with_offset_cpu_complex128", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_with_offset_cpu_float32", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_with_offset_cpu_int32", # test_tensor_creation_ops
- "TestAsArrayCPU.test_copy_from_buffer_cpu_int32", # test_tensor_creation_ops
- "TestAsArrayCPU.test_copy_from_buffer_cpu_int64", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_with_offset_cpu_bool", # test_tensor_creation_ops
"TestAsArrayCPU.test_copy_list_cpu_complex64", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_with_count_and_offset_cpu_int32", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_shared_buffer_cpu_int64", # test_tensor_creation_ops
"TestAsArrayCPU.test_copy_list_cpu_uint8", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_with_offset_cpu_float16", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_with_offset_cpu_int8", # test_tensor_creation_ops
"TestAsArrayCPU.test_copy_list_cpu_bfloat16", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_invalid_positional_args_cpu_int32", # test_tensor_creation_ops
"TestAsArrayCPU.test_copy_list_cpu_bool", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_shared_buffer_cpu_bool", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_same_type_cpu_float64", # test_tensor_creation_ops
"TestTensorCreationCPU.test_constructor_dtypes_cpu", # test_tensor_creation_ops
- "TestAsArrayCPU.test_alias_from_buffer_cpu_int64", # test_tensor_creation_ops
- "TestAsArrayCPU.test_alias_from_buffer_cpu_bool", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_same_type_cpu_complex64", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_requires_grad_cpu_float64", # test_tensor_creation_ops
"TestAsArrayCPU.test_copy_list_cpu_complex128", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_same_type_cpu_int8", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_invalid_positional_args_cpu_float32", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_shared_buffer_cpu_int16", # test_tensor_creation_ops
"TestTensorCreationCPU.test_tensor_factory_copy_var_cpu", # test_tensor_creation_ops
"TestAsArrayCPU.test_copy_list_cpu_int16", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_requires_grad_cpu_int64", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_requires_grad_cpu_bool", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_shared_buffer_cpu_int32", # test_tensor_creation_ops
- "TestAsArrayCPU.test_alias_from_buffer_cpu_int32", # test_tensor_creation_ops
- "TestAsArrayCPU.test_copy_from_buffer_cpu_bool", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_with_count_cpu_complex64", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_shared_buffer_cpu_float64", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_requires_grad_cpu_complex64", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_with_count_cpu_int16", # test_tensor_creation_ops
- "TestAsArrayCPU.test_alias_from_buffer_cpu_float64", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_with_count_cpu_float16", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_invalid_positional_args_cpu_complex64", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_shared_buffer_cpu_complex64", # test_tensor_creation_ops
- "TestAsArrayCPU.test_alias_from_buffer_cpu_int8", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_shared_buffer_cpu_float32", # test_tensor_creation_ops
- "TestBufferProtocolCPU.test_shared_buffer_cpu_float16", # test_tensor_creation_ops
"TestTensorCreationCPU.test_cartesian_prod_cpu", # test_tensor_creation_ops
"TestSubclass.test_parametrization_non_wrapper_tensor_leave_parametrized_True", # test_subclass
"TestSubclass.test_module_optimization_non_wrapper_tensor", # test_subclass