Improve uint{16,32,64} dlpack/numpy compatibility (#116808)

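The wider unsigned integer dtypes now round-trip through both interop
paths: DLConvertor maps DLPack kUInt with 16/32/64 bits to
ScalarType::UInt{16,32,64}, and tensor_numpy.cpp translates
NPY_UINT{16,32,64} in both directions (the from_numpy error message is
updated to match). Type promotion for these dtypes remains
unimplemented, so promotion-dependent tests skip them, and
torch/testing/_comparison.py maps them to int64 before calling
promote_types. A minimal sketch of what this enables (illustrative,
not part of the diff), assuming a NumPy-enabled build:

    import numpy as np
    import torch
    from torch.utils.dlpack import from_dlpack, to_dlpack

    # NumPy -> torch: previously raised TypeError for np.uint16/32/64
    t = torch.from_numpy(np.arange(5, dtype=np.uint32))
    assert t.dtype == torch.uint32

    # torch -> NumPy round trip
    assert t.numpy().dtype == np.uint32

    # DLPack capsule round trip for the new dtypes
    x = torch.from_numpy(np.arange(5, dtype=np.uint16))
    assert from_dlpack(to_dlpack(x)).dtype == torch.uint16
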
Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/116808
Approved by: https://github.com/malfet, https://github.com/albanD
diff --git a/aten/src/ATen/DLConvertor.cpp b/aten/src/ATen/DLConvertor.cpp
index 8f2cac8..8a74e1c 100644
--- a/aten/src/ATen/DLConvertor.cpp
+++ b/aten/src/ATen/DLConvertor.cpp
@@ -148,6 +148,15 @@
         case 8:
           stype = ScalarType::Byte;
           break;
+        case 16:
+          stype = ScalarType::UInt16;
+          break;
+        case 32:
+          stype = ScalarType::UInt32;
+          break;
+        case 64:
+          stype = ScalarType::UInt64;
+          break;
         default:
           TORCH_CHECK(
               false, "Unsupported kUInt bits " + c10::to_string(dtype.bits));
diff --git a/test/test_dlpack.py b/test/test_dlpack.py
index 35fe8ad..87a4657 100644
--- a/test/test_dlpack.py
+++ b/test/test_dlpack.py
@@ -15,7 +15,7 @@
 
     @skipMeta
     @onlyNativeDeviceTypes
-    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
+    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool, torch.uint16, torch.uint32, torch.uint64))
     def test_dlpack_capsule_conversion(self, device, dtype):
         x = make_tensor((5,), dtype=dtype, device=device)
         z = from_dlpack(to_dlpack(x))
@@ -23,7 +23,7 @@
 
     @skipMeta
     @onlyNativeDeviceTypes
-    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
+    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool, torch.uint16, torch.uint32, torch.uint64))
     def test_dlpack_protocol_conversion(self, device, dtype):
         x = make_tensor((5,), dtype=dtype, device=device)
         z = from_dlpack(x)
@@ -62,7 +62,7 @@
 
     @skipMeta
     @onlyNativeDeviceTypes
-    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
+    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool, torch.uint16, torch.uint32, torch.uint64))
     def test_from_dlpack(self, device, dtype):
         x = make_tensor((5,), dtype=dtype, device=device)
         y = torch.from_dlpack(x)
@@ -70,7 +70,7 @@
 
     @skipMeta
     @onlyNativeDeviceTypes
-    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
+    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool, torch.uint16, torch.uint32, torch.uint64))
     def test_from_dlpack_noncontinguous(self, device, dtype):
         x = make_tensor((25,), dtype=dtype, device=device).reshape(5, 5)
 
@@ -113,7 +113,7 @@
 
     @skipMeta
     @onlyNativeDeviceTypes
-    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))
+    @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool, torch.uint16, torch.uint32, torch.uint64))
     def test_from_dlpack_dtype(self, device, dtype):
         x = make_tensor((5,), dtype=dtype, device=device)
         y = torch.from_dlpack(x)
diff --git a/test/test_numpy_interop.py b/test/test_numpy_interop.py
index a0b9561..ced9192 100644
--- a/test/test_numpy_interop.py
+++ b/test/test_numpy_interop.py
@@ -220,7 +220,7 @@
                     self.assertEqual(tensor_from_array2[i], array2[i])
 
         # Test unsupported type
-        array = np.array([1, 2, 3, 4], dtype=np.uint16)
+        array = np.array(['foo', 'bar'], dtype=np.dtype(np.str_))
         with self.assertRaises(TypeError):
             tensor_from_array = torch.from_numpy(array)
 
@@ -417,7 +417,7 @@
     @onlyCPU
     def test_parse_numpy_int(self, device):
         # Only concrete class can be given where "Type[number[_64Bit]]" is expected
-        self.assertRaisesRegex(RuntimeError, "Overflow",
+        self.assertRaisesRegex(RuntimeError, "(Overflow|an integer is required)",
                                lambda: torch.mean(torch.randn(1, 1), np.uint64(-1)))  # type: ignore[call-overload]
         # https://github.com/pytorch/pytorch/issues/29252
         for nptype in [np.int16, np.int8, np.uint8, np.int32, np.int64]:
diff --git a/test/test_reductions.py b/test/test_reductions.py
index 45bff79..4295630 100644
--- a/test/test_reductions.py
+++ b/test/test_reductions.py
@@ -1666,41 +1666,43 @@
         def is_integral(dtype):
             return dtype in integral_types()
 
+        exact_dtype = True
         # On Windows CI, the current version of `numpy` promotes all lower integers
         # dtypes to int32 while `torch` promotes them to int64. Hence we skip on checking
         # the exact dtype.
         # Reference : https://dr.pytorch.org/api/view-log-full?build_id=122051580
         # PR : https://github.com/pytorch/pytorch/pull/38628#issuecomment-655905370
-        exact_dtype = False if (IS_WINDOWS and is_integral(dtype)) else True
-
+        if IS_WINDOWS and is_integral(dtype):
+            exact_dtype = False
+        # For uint8, numpy promotes to uint64 while torch promotes to int64.
+        # So we must skip this as well.
         if dtype == torch.uint8:
-            with self.assertRaises(TypeError):
-                self._test_reduction_function_with_numpy(torch_fn, np_fn, device, dtype, with_extremal=with_extremal)
+            exact_dtype = False
+
+        # TODO: Investigate why the output is not close to numpy.
+        if dtype == torch.float16:
+            atol = 0.4
+            rtol = 1e-2
+        elif dtype == torch.float32:
+            atol = 7e-05
+            rtol = 3e-06
         else:
-            # TODO: Investigate why the output is not close to numpy.
-            if dtype == torch.float16:
-                atol = 0.4
-                rtol = 1e-2
-            elif dtype == torch.float32:
-                atol = 7e-05
-                rtol = 3e-06
-            else:
-                # Default values
-                atol = None
-                rtol = None
-            self._test_reduction_function_with_numpy(torch_fn, np_fn, device, dtype,
-                                                     atol=atol, rtol=rtol, exact_dtype=exact_dtype,
-                                                     with_keepdim=with_keepdim, with_extremal=with_extremal)
+            # Default values
+            atol = None
+            rtol = None
+        self._test_reduction_function_with_numpy(torch_fn, np_fn, device, dtype,
+                                                 atol=atol, rtol=rtol, exact_dtype=exact_dtype,
+                                                 with_keepdim=with_keepdim, with_extremal=with_extremal)
 
     @onlyNativeDeviceTypes
-    @dtypes(*all_types_and(torch.half))
+    @dtypes(*set(all_types_and(torch.half)) - {torch.uint8})
     def test_sum_vs_numpy(self, device, dtype):
         self._test_sum_reduction_vs_numpy(torch.sum, np.sum, device, dtype)
         self._test_sum_reduction_vs_numpy(torch.sum, np.sum, device, dtype, with_extremal=True)
         self._test_sum_reduction_vs_numpy(torch.sum, np.sum, device, dtype, with_keepdim=True)
 
     @onlyNativeDeviceTypes
-    @dtypes(*all_types_and(torch.half))
+    @dtypes(*set(all_types_and(torch.half)) - {torch.uint8})
     def test_nansum_vs_numpy(self, device, dtype):
         self._test_sum_reduction_vs_numpy(torch.nansum, np.nansum, device, dtype)
         self._test_sum_reduction_vs_numpy(torch.nansum, np.nansum, device, dtype, with_extremal=True)
@@ -3528,6 +3530,10 @@
             # Workaround https://github.com/pytorch/pytorch/issues/66556
             expected = np.asarray(expected)  # transform numpy scalars to numpy.ndarray instances
 
+            # Numpy differs, producing uint32 on Windows
+            if expected.dtype in [np.uint64, np.uint32]:
+                exact_dtype = False
+
             msg = ("Failed to produce expected results! Input tensor was"
                    f" {t}, torch result is {actual}, and reference result is"
                    f" {expected}.") if t.numel() < 10 else None
diff --git a/test/test_tensor_creation_ops.py b/test/test_tensor_creation_ops.py
index bfc7e51..1a3adb4 100644
--- a/test/test_tensor_creation_ops.py
+++ b/test/test_tensor_creation_ops.py
@@ -543,10 +543,11 @@
             self.assertEqual(y, expected_y)
             self.assertEqual(x, expected_x)
 
-    @dtypes(*all_types_and_complex())
+    @dtypes(*all_types_and_complex(), torch.uint16, torch.uint32, torch.uint64)
     def test_cat_out_fast_path_dim0_dim1(self, device, dtype):
+        int_types = integral_types_and(torch.uint16, torch.uint32, torch.uint64)
         x = torch.zeros((0), device=device, dtype=dtype)
-        if dtype in integral_types():
+        if dtype in int_types:
             y = torch.randint(low=0, high=100, size=(4, 6), device=device, dtype=dtype)
         else:
             y = torch.randn((4, 6), device=device, dtype=dtype)
@@ -575,7 +576,7 @@
         self.assertEqual(b_fastcat, expected_b)
         # Finally, we need to make sure backward is not broken
         # Integral types will not have grad
-        if dtype not in integral_types():
+        if dtype not in int_types:
             a = torch.randn((4, 3), device=device, dtype=dtype, requires_grad=True)
             b = torch.randn((2, 3), device=device, dtype=dtype, requires_grad=True)
             c = torch.randn((5, 3), device=device, dtype=dtype, requires_grad=True)
@@ -1046,6 +1047,8 @@
         self._float_to_int_conversion_helper(vals, device, dtype, refs)
 
     # Note: CUDA will fail this test on most dtypes, often dramatically.
+    # NB: torch.uint16, torch.uint32, torch.uint64 excluded as this
+    # nondeterministically fails, warning "invalid value encountered in cast"
     @onlyCPU
     @dtypes(torch.bool, torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64)
     def test_float_to_int_conversion_nonfinite(self, device, dtype):
diff --git a/test/test_type_promotion.py b/test/test_type_promotion.py
index ceef13b..dd96bd2 100644
--- a/test/test_type_promotion.py
+++ b/test/test_type_promotion.py
@@ -941,8 +941,10 @@
     @unittest.skipIf(not TEST_NUMPY, "NumPy not found")
     @float_double_default_dtype
     @onlyCPU
-    @dtypes(*list(itertools.product(set(numpy_to_torch_dtype_dict.values()),
-                                    set(numpy_to_torch_dtype_dict.values()))))
+    # NB: skip uint16,32,64 as PyTorch doesn't implement promotion for them
+    @dtypes(*list(itertools.product(
+        set(numpy_to_torch_dtype_dict.values()) - {torch.uint16, torch.uint32, torch.uint64},
+        set(numpy_to_torch_dtype_dict.values()) - {torch.uint16, torch.uint32, torch.uint64})))
     def test_numpy_array_binary_ufunc_promotion(self, device, dtypes):
         import operator
         np_type = torch_to_numpy_dtype_dict[dtypes[0]]
diff --git a/torch/_dynamo/utils.py b/torch/_dynamo/utils.py
index b684107..d32cf9c 100644
--- a/torch/_dynamo/utils.py
+++ b/torch/_dynamo/utils.py
@@ -423,6 +423,9 @@
         return True
     except TypeError:
         return False
+    # cannot hash writable memoryview object
+    except ValueError:
+        return False
 
 
 def nothing(*args, **kwargs):
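Aside on the torch/_dynamo/utils.py hunk above: CPython raises
ValueError rather than TypeError when hashing a writable memoryview,
so a hashability probe that only catches TypeError misses that case.
Illustrative REPL session:

    >>> hash(memoryview(b"ab"))             # read-only view: returns a hash
    >>> hash(memoryview(bytearray(b"ab")))  # writable view raises
    Traceback (most recent call last):
      ...
    ValueError: cannot hash writable memoryview object
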
diff --git a/torch/csrc/utils/tensor_numpy.cpp b/torch/csrc/utils/tensor_numpy.cpp
index 95147c7..ed36d96 100644
--- a/torch/csrc/utils/tensor_numpy.cpp
+++ b/torch/csrc/utils/tensor_numpy.cpp
@@ -295,6 +295,12 @@
       return NPY_INT8;
     case kByte:
       return NPY_UINT8;
+    case kUInt16:
+      return NPY_UINT16;
+    case kUInt32:
+      return NPY_UINT32;
+    case kUInt64:
+      return NPY_UINT64;
     case kBool:
       return NPY_BOOL;
     default:
@@ -320,6 +326,12 @@
       return kChar;
     case NPY_UINT8:
       return kByte;
+    case NPY_UINT16:
+      return kUInt16;
+    case NPY_UINT32:
+      return kUInt32;
+    case NPY_UINT64:
+      return kUInt64;
     case NPY_BOOL:
       return kBool;
     default:
@@ -346,7 +358,7 @@
     throw python_error();
   throw TypeError(
       "can't convert np.ndarray of type %s. The only supported types are: "
-      "float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint8, and bool.",
+      "float64, float32, float16, complex64, complex128, int64, int32, int16, int8, uint64, uint32, uint16, uint8, and bool.",
       ((PyTypeObject*)pytype.get())->tp_name);
 }
 
diff --git a/torch/testing/_comparison.py b/torch/testing/_comparison.py
index 1474501..3d405fe 100644
--- a/torch/testing/_comparison.py
+++ b/torch/testing/_comparison.py
@@ -800,7 +800,16 @@
             expected = expected.cpu()
 
         if actual.dtype != expected.dtype:
-            dtype = torch.promote_types(actual.dtype, expected.dtype)
+            actual_dtype = actual.dtype
+            expected_dtype = expected.dtype
+            # For uint64, this is not sound in general, which is why promote_types doesn't
+            # allow it, but for easy testing, we're unlikely to get confused
+            # by large uint64 overflowing into negative int64
+            if actual_dtype in [torch.uint64, torch.uint32, torch.uint16]:
+                actual_dtype = torch.int64
+            if expected_dtype in [torch.uint64, torch.uint32, torch.uint16]:
+                expected_dtype = torch.int64
+            dtype = torch.promote_types(actual_dtype, expected_dtype)
             actual = actual.to(dtype)
             expected = expected.to(dtype)
 
diff --git a/torch/testing/_internal/common_dtype.py b/torch/testing/_internal/common_dtype.py
index 8d7d2bf..fca14ba 100644
--- a/torch/testing/_internal/common_dtype.py
+++ b/torch/testing/_internal/common_dtype.py
@@ -44,6 +44,7 @@
 def double_types():
     return _double_types
 
+# NB: Does not contain uint16/uint32/uint64 for BC reasons
 _integral_types = _dispatch_dtypes((torch.uint8, torch.int8, torch.int16, torch.int32, torch.int64))
 def integral_types():
     return _integral_types
diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index 6f7ea70..e912978 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -1491,6 +1491,9 @@
 numpy_to_torch_dtype_dict = {
     np.bool_      : torch.bool,
     np.uint8      : torch.uint8,
+    np.uint16     : torch.uint16,
+    np.uint32     : torch.uint32,
+    np.uint64     : torch.uint64,
     np.int8       : torch.int8,
     np.int16      : torch.int16,
     np.int32      : torch.int32,
diff --git a/torch/testing/_internal/dynamo_test_failures.py b/torch/testing/_internal/dynamo_test_failures.py
index 13de3df..c369bdd 100644
--- a/torch/testing/_internal/dynamo_test_failures.py
+++ b/torch/testing/_internal/dynamo_test_failures.py
@@ -598,7 +598,6 @@
     "TestShuffle.test_1d_use_numpy_False",
     "TestShuffle.test_2d_use_numpy_True",
     "TestShuffle.test_2d_use_numpy_False",
-    "TestArrayCreationCopyArgument.test_buffer_interface",
     "TestWritebackIfCopy.test_take_mode_raise",
     "TestArange.test_infinite",
     "TestArrayConstruction.test_array_empty",
@@ -609,7 +608,6 @@
     "TestFromBuffer.test_basic_little_dtype2",
     "TestArrayCreationCopyArgument.test_striding_not_ok",
     "TestArange.test_require_range",
-    "TestStats.test_dtype_from_input",
     "TestArange.test_nan_step",
     "TestWritebackIfCopy.test_argmin_with_out",
     "TestArrayAttributeDeletion.test_multiarray_not_writable_attributes_deletion",
@@ -2402,122 +2400,23 @@
     "TestTensorBoardSummary.test_hparams_string",  # test_tensorboard
     "TestTensorBoardSummary.test_hparams_bool",  # test_tensorboard
     "TestTensorBoardSummary.test_uint8_image",  # test_tensorboard
-    "TestBufferProtocolCPU.test_shared_buffer_cpu_uint8",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_with_count_and_offset_cpu_complex128",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_same_type_cpu_complex128",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_invalid_positional_args_cpu_float16",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_requires_grad_cpu_int32",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_invalid_positional_args_cpu_uint8",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_with_count_cpu_complex128",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_requires_grad_cpu_uint8",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_with_count_cpu_uint8",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_with_count_cpu_float64",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_shared_buffer_cpu_complex128",  # test_tensor_creation_ops
     "TestAsArrayCPU.test_copy_list_cpu_float64",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_with_count_cpu_float32",  # test_tensor_creation_ops
     "TestAsArrayCPU.test_copy_list_cpu_int64",  # test_tensor_creation_ops
-    "TestAsArrayCPU.test_alias_from_buffer_cpu_float16",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_with_offset_cpu_int64",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_same_type_cpu_int16",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_with_offset_cpu_float64",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_invalid_positional_args_cpu_float64",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_same_type_cpu_float32",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_with_count_and_offset_cpu_uint8",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_with_count_and_offset_cpu_bool",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_same_type_cpu_int32",  # test_tensor_creation_ops
-    "TestAsArrayCPU.test_alias_from_buffer_cpu_int16",  # test_tensor_creation_ops
-    "TestAsArrayCPU.test_alias_from_buffer_cpu_float32",  # test_tensor_creation_ops
     "TestAsArrayCPU.test_copy_list_cpu_int32",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_requires_grad_cpu_int8",  # test_tensor_creation_ops
-    "TestAsArrayCPU.test_copy_from_buffer_cpu_int16",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_with_count_and_offset_cpu_int8",  # test_tensor_creation_ops
-    "TestAsArrayCPU.test_copy_from_buffer_cpu_int8",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_requires_grad_cpu_float16",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_invalid_positional_args_cpu_int64",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_same_type_cpu_int64",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_with_count_and_offset_cpu_float16",  # test_tensor_creation_ops
-    "TestAsArrayCPU.test_copy_from_buffer_cpu_complex64",  # test_tensor_creation_ops
-    "TestAsArrayCPU.test_alias_from_buffer_cpu_uint8",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_with_count_and_offset_cpu_int64",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_invalid_positional_args_cpu_complex128",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_with_count_and_offset_cpu_complex64",  # test_tensor_creation_ops
-    "TestAsArrayCPU.test_copy_from_buffer_cpu_uint8",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_invalid_positional_args_cpu_bool",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_with_count_cpu_int32",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_invalid_positional_args_cpu_int16",  # test_tensor_creation_ops
-    "TestAsArrayCPU.test_copy_from_buffer_cpu_complex128",  # test_tensor_creation_ops
-    "TestAsArrayCPU.test_copy_from_buffer_cpu_float32",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_with_offset_cpu_int16",  # test_tensor_creation_ops
     "TestAsArrayCPU.test_copy_list_cpu_float32",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_with_count_and_offset_cpu_float64",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_requires_grad_cpu_float32",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_same_type_cpu_uint8",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_with_count_and_offset_cpu_int16",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_shared_buffer_cpu_int8",  # test_tensor_creation_ops
-    "TestAsArrayCPU.test_copy_from_buffer_cpu_float64",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_with_count_cpu_int8",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_invalid_positional_args_cpu_int8",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_with_count_and_offset_cpu_float32",  # test_tensor_creation_ops
-    "TestAsArrayCPU.test_alias_from_buffer_cpu_complex64",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_same_type_cpu_float16",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_same_type_cpu_bool",  # test_tensor_creation_ops
     "TestTensorCreationCPU.test_tensor_factory_type_inference_cpu",  # test_tensor_creation_ops
-    "TestAsArrayCPU.test_copy_from_buffer_cpu_float16",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_with_count_cpu_bool",  # test_tensor_creation_ops
     "TestBufferProtocolCPU.test_byte_to_int_cpu",  # test_tensor_creation_ops
     "TestTensorCreationCPU.test_block_diag_cpu",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_with_offset_cpu_complex64",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_with_offset_cpu_uint8",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_with_count_cpu_int64",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_requires_grad_cpu_int16",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_requires_grad_cpu_complex128",  # test_tensor_creation_ops
     "TestAsArrayCPU.test_copy_list_cpu_int8",  # test_tensor_creation_ops
-    "TestAsArrayCPU.test_alias_from_buffer_cpu_complex128",  # test_tensor_creation_ops
     "TestAsArrayCPU.test_copy_list_cpu_float16",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_with_offset_cpu_complex128",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_with_offset_cpu_float32",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_with_offset_cpu_int32",  # test_tensor_creation_ops
-    "TestAsArrayCPU.test_copy_from_buffer_cpu_int32",  # test_tensor_creation_ops
-    "TestAsArrayCPU.test_copy_from_buffer_cpu_int64",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_with_offset_cpu_bool",  # test_tensor_creation_ops
     "TestAsArrayCPU.test_copy_list_cpu_complex64",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_with_count_and_offset_cpu_int32",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_shared_buffer_cpu_int64",  # test_tensor_creation_ops
     "TestAsArrayCPU.test_copy_list_cpu_uint8",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_with_offset_cpu_float16",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_with_offset_cpu_int8",  # test_tensor_creation_ops
     "TestAsArrayCPU.test_copy_list_cpu_bfloat16",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_invalid_positional_args_cpu_int32",  # test_tensor_creation_ops
     "TestAsArrayCPU.test_copy_list_cpu_bool",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_shared_buffer_cpu_bool",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_same_type_cpu_float64",  # test_tensor_creation_ops
     "TestTensorCreationCPU.test_constructor_dtypes_cpu",  # test_tensor_creation_ops
-    "TestAsArrayCPU.test_alias_from_buffer_cpu_int64",  # test_tensor_creation_ops
-    "TestAsArrayCPU.test_alias_from_buffer_cpu_bool",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_same_type_cpu_complex64",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_requires_grad_cpu_float64",  # test_tensor_creation_ops
     "TestAsArrayCPU.test_copy_list_cpu_complex128",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_same_type_cpu_int8",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_invalid_positional_args_cpu_float32",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_shared_buffer_cpu_int16",  # test_tensor_creation_ops
     "TestTensorCreationCPU.test_tensor_factory_copy_var_cpu",  # test_tensor_creation_ops
     "TestAsArrayCPU.test_copy_list_cpu_int16",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_requires_grad_cpu_int64",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_requires_grad_cpu_bool",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_shared_buffer_cpu_int32",  # test_tensor_creation_ops
-    "TestAsArrayCPU.test_alias_from_buffer_cpu_int32",  # test_tensor_creation_ops
-    "TestAsArrayCPU.test_copy_from_buffer_cpu_bool",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_with_count_cpu_complex64",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_shared_buffer_cpu_float64",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_requires_grad_cpu_complex64",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_with_count_cpu_int16",  # test_tensor_creation_ops
-    "TestAsArrayCPU.test_alias_from_buffer_cpu_float64",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_with_count_cpu_float16",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_invalid_positional_args_cpu_complex64",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_shared_buffer_cpu_complex64",  # test_tensor_creation_ops
-    "TestAsArrayCPU.test_alias_from_buffer_cpu_int8",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_shared_buffer_cpu_float32",  # test_tensor_creation_ops
-    "TestBufferProtocolCPU.test_shared_buffer_cpu_float16",  # test_tensor_creation_ops
     "TestTensorCreationCPU.test_cartesian_prod_cpu",  # test_tensor_creation_ops
     "TestSubclass.test_parametrization_non_wrapper_tensor_leave_parametrized_True",  # test_subclass
     "TestSubclass.test_module_optimization_non_wrapper_tensor",  # test_subclass