| # Owner(s): ["module: tests"] | 
 | import torch | 
 | import numpy as np | 
 |  | 
 | import unittest | 
 | from itertools import product, permutations, combinations | 
 | from functools import partial | 
 | import random | 
 |  | 
 | from torch.testing import make_tensor | 
 | from torch.testing._internal.common_utils import ( | 
 |     IS_FBCODE, TestCase, run_tests, suppress_warnings, gradcheck, gradgradcheck, | 
 |     numpy_to_torch_dtype_dict, skipIfTorchDynamo | 
 | ) | 
 | from torch.testing._internal.common_device_type import \ | 
 |     (instantiate_device_type_tests, onlyCPU, dtypes, onlyNativeDeviceTypes, skipMeta) | 
 | from torch.testing._internal.common_dtype import ( | 
 |     all_types_and_complex_and, complex_types, all_types_and, floating_and_complex_types_and, | 
 | ) | 
 |  | 
 | # TODO: replace this with make_tensor() in common_utils.py | 
 | def _generate_input(shape, dtype, device, with_extremal): | 
 |     if shape == (): | 
 |         x = torch.tensor((), dtype=dtype, device=device) | 
 |     else: | 
 |         if dtype.is_floating_point or dtype.is_complex: | 
 |             # work around torch.randn not being implemented for bfloat16 | 
 |             if dtype == torch.bfloat16: | 
 |                 x = torch.randn(*shape, device=device) * random.randint(30, 100) | 
 |                 x = x.to(torch.bfloat16) | 
 |             else: | 
 |                 x = torch.randn(*shape, dtype=dtype, device=device) * random.randint(30, 100) | 
 |             x[torch.randn(*shape) > 0.5] = 0 | 
 |             if with_extremal and dtype.is_floating_point: | 
 |                 # Use extremal values | 
 |                 x[torch.randn(*shape) > 0.5] = float('nan') | 
 |                 x[torch.randn(*shape) > 0.5] = float('inf') | 
 |                 x[torch.randn(*shape) > 0.5] = float('-inf') | 
 |             elif with_extremal and dtype.is_complex: | 
 |                 x[torch.randn(*shape) > 0.5] = complex('nan') | 
 |                 x[torch.randn(*shape) > 0.5] = complex('inf') | 
 |                 x[torch.randn(*shape) > 0.5] = complex('-inf') | 
 |         elif dtype == torch.bool: | 
 |             x = torch.zeros(shape, dtype=dtype, device=device) | 
 |             x[torch.randn(*shape) > 0.5] = True | 
 |         else: | 
 |             x = torch.randint(15, 100, shape, dtype=dtype, device=device) | 
 |  | 
 |     return x | 
 |  | 
 | # TODO: replace this with make_tensor() in common_utils.py | 
 | def _rand_shape(dim, min_size, max_size): | 
 |     shape = [] | 
 |     for i in range(dim): | 
 |         shape.append(random.randint(min_size, max_size)) | 
 |     return tuple(shape) | 
 |  | 
 | # TODO: refactor tests to avoid this function | 
 | # Converts half/bfloat16 dtype to float when device is cpu | 
 | def _convert_t(dtype, device): | 
 |     if device == 'cpu' and dtype in {torch.half, torch.bfloat16}: | 
 |         return torch.float | 
 |     return dtype | 
 |  | 
 | # TODO: replace this with make_tensor() in common_utils.py | 
 | # Returns a tensor of the requested shape, dtype, and device | 
 | # Requesting a half CPU tensor returns a float CPU tensor with | 
 | # values representable by a half. | 
 | # Initialization uses randint for non-float types and randn for float types. | 
 | def _make_tensor(shape, dtype, device, fill_ones=False) -> torch.Tensor: | 
 |     # Returns a tensor filled with ones | 
 |     if fill_ones: | 
 |         return torch.ones(*shape, dtype=_convert_t(dtype, device), device=device) | 
 |  | 
 |     # Returns a tensor with random integer values | 
 |     if not (dtype.is_floating_point or dtype.is_complex): | 
 |         t = torch.randint(0, 10, shape, device=device) | 
 |         if dtype != torch.uint8: | 
 |             t = t - 5  # generate negative values also | 
 |         return t.to(_convert_t(dtype, device)) | 
 |  | 
 |     # Populates the CPU tensor with floats representable as half/bfloat16 | 
 |     if dtype == torch.half and device == 'cpu': | 
 |         return torch.randn(*shape, dtype=torch.float, device=device).half().float() | 
 |     if dtype == torch.bfloat16 and device == 'cpu': | 
 |         return torch.randn(*shape, dtype=torch.float, device=device).bfloat16().float() | 
 |  | 
 |     # Default: returns a tensor with random float values | 
 |     return torch.randn(shape, dtype=dtype, device=device).to(dtype=dtype) | 
 |  | 
 | # Tests ops and indexing to ensure they return views (and new tensors) as | 
 | # appropriate. | 
 | class TestViewOps(TestCase): | 
 |     exact_dtype = True | 
 |  | 
 |     def is_view_of(self, base, other): | 
 |         if (not other._is_view() or | 
 |                 other is base or | 
 |                 other._base is not base or | 
 |                 base.device != other.device): | 
 |             return False | 
 |         # Note: only validates storage on native device types | 
 |         # because some accelerators, like XLA, do not expose storage | 
 |         if base.device.type == 'cpu' or base.device.type == 'cuda': | 
 |             if base.untyped_storage().data_ptr() != other.untyped_storage().data_ptr(): | 
 |                 return False | 
 |  | 
 |         return True | 
 |  | 
 |     # Returns true if v1 and v2 are views of the same base | 
 |     def is_view_of_same_base(self, v1, v2): | 
 |         if (not v1._is_view() or v1 is v2): | 
 |             return False | 
 |         return self.is_view_of(v1._base, v2) | 
 |  | 
 |     # Performs transpose if contiguous=True, else returns the input tensor as is | 
 |     def _do_transpose(self, x, contiguous=False, dim0=0, dim1=1): | 
 |         if contiguous: | 
 |             return x | 
 |         else: | 
 |             return x.transpose(dim0, dim1) | 
 |  | 
 |     @dtypes(*all_types_and(torch.half, torch.bfloat16)) | 
 |     def test_conj_self(self, device, dtype): | 
 |         t = torch.ones(5, 5, device=device) | 
 |         s = t.conj() | 
 |         self.assertTrue(s is t) | 
 |  | 
 |     @skipIfTorchDynamo("TorchDynamo fails with unknown reason") | 
 |     @onlyNativeDeviceTypes | 
 |     @dtypes(*all_types_and_complex_and(torch.half, torch.bool)) | 
 |     def test_view_dtype_new(self, device, dtype): | 
 |         dtypes = {value : key for (key, value) in numpy_to_torch_dtype_dict.items()} | 
 |         del dtypes[torch.bool] | 
 |  | 
 |         def generate_inputs(): | 
 |             yield make_tensor((4, 4, 64), dtype=dtype, device=device, low=-5, high=5) | 
 |             yield make_tensor((4, 4, 64), dtype=dtype, device=device, low=-5, high=5).permute(1, 0, 2) | 
 |             yield make_tensor((4, 64, 4), dtype=dtype, device=device, low=-5, high=5).permute(2, 0, 1) | 
 |             yield make_tensor((1, 5, 1), dtype=dtype, device=device, low=-5, high=5).expand(5, 5, 64) | 
 |             yield make_tensor((2, 5, 256), dtype=dtype, device=device, low=-5, high=5)[1::2, 1:, ::2] | 
 |             yield make_tensor((0, 5, 64), dtype=dtype, device=device, low=-5, high=5) | 
 |             yield make_tensor((), dtype=dtype, device=device, low=-5, high=5) | 
 |  | 
 |         def calc_expected_size_and_stride(a, view_dtype): | 
 |             dtype_size = torch._utils._element_size(a.dtype) | 
 |             view_dtype_size = torch._utils._element_size(view_dtype) | 
 |  | 
 |             if dtype_size == view_dtype_size: | 
 |                 return a.size(), a.stride() | 
 |  | 
 |             elif dtype_size > view_dtype_size: | 
 |                 size_ratio = dtype_size // view_dtype_size | 
 |  | 
 |                 view_size = list(a.size()) | 
 |                 view_size[-1] = view_size[-1] * size_ratio | 
 |  | 
 |                 view_stride = [stride * size_ratio for stride in a.stride()] | 
 |                 view_stride[-1] = 1 | 
 |                 return torch.Size(view_size), tuple(view_stride) | 
 |  | 
 |             else: | 
 |                 size_ratio = view_dtype_size // dtype_size | 
 |  | 
 |                 view_size = list(a.size()) | 
 |                 view_size[-1] = view_size[-1] // size_ratio | 
 |  | 
 |                 view_stride = [stride // size_ratio for stride in a.stride()] | 
 |                 view_stride[-1] = 1 | 
 |                 return torch.Size(view_size), tuple(view_stride) | 
 |  | 
 |         for a in generate_inputs(): | 
 |             a_np = a.cpu().numpy() | 
 |             a_np_contiguous = a.cpu().contiguous().numpy() | 
 |  | 
 |             for view_dtype, np_view_dtype in dtypes.items(): | 
 |                 equal_element_size = torch._utils._element_size(dtype) == torch._utils._element_size(view_dtype) | 
 |  | 
 |                 if not equal_element_size and a.dim() == 0: | 
 |                     with self.assertRaisesRegex(RuntimeError, r"self.dim\(\) cannot be 0"): | 
 |                         a.view(view_dtype) | 
 |                     continue | 
 |  | 
 |                 if not equal_element_size and a.stride(-1) != 1: | 
 |                     with self.assertRaisesRegex(RuntimeError, r"self.stride\(-1\) must be 1"): | 
 |                         a.view(view_dtype) | 
 |                     continue | 
 |  | 
 |                 a_view = a.view(view_dtype) | 
 |                 self.assertEqual(a_view.dtype, view_dtype) | 
 |                 self.assertEqual(a.data_ptr(), a_view.data_ptr()) | 
 |  | 
 |                 expected_size, expected_stride = calc_expected_size_and_stride(a, view_dtype) | 
 |                 self.assertEqual(a_view.size(), expected_size) | 
 |                 self.assertEqual(a_view.stride(), expected_stride) | 
 |  | 
 |                 self.assertEqual(a_view.view(dtype), a, rtol=0, atol=0) | 
 |  | 
 |                 # NumPy's dtype view requires contiguous input if target | 
 |                 # dtype is a different size | 
 |                 if equal_element_size: | 
 |                     a_np_view = a_np.view(np_view_dtype) | 
 |  | 
 |                 else: | 
 |                     a_np_view = a_np_contiguous.view(np_view_dtype) | 
 |  | 
 |                 self.assertEqual(a_view, a_np_view) | 
 |  | 
 |         # Test that requires_grad is dropped for floating point casts, | 
 |         # because view(dtype) does not support backward yet | 
 |         # TODO: Remove this when autograd support is added | 
 |         if dtype.is_floating_point or dtype.is_complex: | 
 |             for view_dtype in floating_and_complex_types_and(torch.half, torch.bfloat16): | 
 |                 t = make_tensor((5, 5, 64), dtype=dtype, device=device, low=-5, high=5, requires_grad=True) | 
 |                 self.assertFalse(t.view(view_dtype).requires_grad) | 
 |  | 
 |     # Test the extra error checks that happen when the view dtype | 
 |     # has a greater element size than the original dtype | 
 |     @onlyNativeDeviceTypes | 
 |     @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)) | 
 |     def test_view_dtype_upsize_errors(self, device, dtype): | 
 |         dtype_size = torch._utils._element_size(dtype) | 
 |  | 
 |         for view_dtype in all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool): | 
 |             view_dtype_size = torch._utils._element_size(view_dtype) | 
 |             if view_dtype_size <= dtype_size: | 
 |                 continue | 
 |  | 
 |             size_ratio = view_dtype_size // dtype_size | 
 |             a = make_tensor((4, 4, size_ratio + 1), dtype=dtype, device=device, low=-5, high=5) | 
 |             with self.assertRaisesRegex( | 
 |                     RuntimeError, | 
 |                     rf"self.size\(-1\) must be divisible by {size_ratio}"): | 
 |                 a.view(view_dtype) | 
 |  | 
 |             with self.assertRaisesRegex( | 
 |                     RuntimeError, | 
 |                     rf"self.storage_offset\(\) must be divisible by {size_ratio}"): | 
 |                 a[:, :, 1:].view(view_dtype) | 
 |  | 
 |             a = make_tensor((4, 4, size_ratio), dtype=dtype, device=device, low=-5, high=5) | 
 |             a = a.as_strided((4, 4, size_ratio), (size_ratio, 1, 1)) | 
 |             with self.assertRaisesRegex( | 
 |                     RuntimeError, | 
 |                     rf"self.stride\(1\) must be divisible by {size_ratio}"): | 
 |                 a.view(view_dtype) | 
 |  | 
 |     @onlyNativeDeviceTypes | 
 |     def test_view_as_complex(self, device): | 
 |         def fn(contiguous_input=True, dim0=0, dim1=1): | 
 |             t = torch.randn(3, 2, 2, device=device) | 
 |             c_t = t[:, :, 0] + 1j * t[:, :, 1] | 
 |  | 
 |             input = self._do_transpose(t, contiguous_input, dim0, dim1) | 
 |  | 
 |             if input.size()[-1] != 2: | 
 |                 self.assertRaisesRegex( | 
 |                     RuntimeError, "Tensor must have a last dimension of size 2", | 
 |                     lambda: torch.view_as_complex(input)) | 
 |                 return | 
 |  | 
 |             if input.stride()[-1] != 1: | 
 |                 self.assertRaisesRegex( | 
 |                     RuntimeError, "Tensor must have a last dimension with stride 1", | 
 |                     lambda: torch.view_as_complex(input)) | 
 |                 return | 
 |  | 
 |             res = torch.view_as_complex(input) | 
 |             self.assertEqual(res, self._do_transpose(c_t, contiguous_input, dim0, dim1)) | 
 |             self.assertTrue(self.is_view_of(t, res)) | 
 |  | 
 |         fn() | 
 |         fn(contiguous_input=False) | 
 |         # RuntimeError since in this case the last dim of input would not be of size 2 | 
 |         fn(contiguous_input=False, dim0=0, dim1=2) | 
 |         # RuntimeError since in this case the last dim of input would not have stride 1 | 
 |         fn(contiguous_input=False, dim0=1, dim1=2) | 
 |  | 
 |  | 
 |         # RuntimeError since in this case the stride of non-last dim of input would not be of size 2 | 
 |         x = torch.randn(3, 3, device=device) | 
 |         t = torch.as_strided(x, (2, 2), (1, 1)) | 
 |         self.assertRaisesRegex( | 
 |             RuntimeError, "Tensor must have a stride divisible by 2 for all but last dimension", | 
 |             lambda: torch.view_as_complex(t)) | 
 |  | 
 |         # tensor with zero elements | 
 |         x = torch.tensor([], device=device)  # torch.Size([0]) | 
 |         self.assertRaisesRegex( | 
 |             RuntimeError, "Tensor must have a last dimension of size 2", | 
 |             lambda: torch.view_as_complex(x)) | 
 |  | 
 |         # zero dimension tensor | 
 |         z = torch.tensor(2.0) | 
 |         self.assertRaisesRegex( | 
 |             RuntimeError, "Input tensor must have one or more dimensions", | 
 |             lambda: torch.view_as_complex(z)) | 
 |  | 
 |         y = x.reshape(0, 2)  # torch.Size([0, 2]) | 
 |         res = torch.view_as_complex(y) | 
 |         self.assertTrue(self.is_view_of(x, res)) | 
 |         self.assertEqual(res.shape, torch.Size([0])) | 
 |  | 
 |     @onlyNativeDeviceTypes | 
 |     @dtypes(*complex_types(), torch.complex32) | 
 |     def test_view_as_real(self, device, dtype): | 
 |         def fn(contiguous_input=True): | 
 |             t = torch.randn(3, 4, dtype=dtype, device=device) | 
 |             input = self._do_transpose(t, contiguous_input) | 
 |             res = torch.view_as_real(input) | 
 |             self.assertEqual(res[:, :, 0], input.real) | 
 |             self.assertEqual(res[:, :, 1], input.imag) | 
 |             self.assertTrue(self.is_view_of(t, res)) | 
 |  | 
 |         fn() | 
 |         fn(contiguous_input=False) | 
 |  | 
 |         # tensor with zero elements | 
 |         x = torch.tensor([], dtype=dtype, device=device) | 
 |         res = torch.view_as_real(x) | 
 |         self.assertTrue(self.is_view_of(x, res)) | 
 |         self.assertEqual(res.shape, torch.Size([0, 2])) | 
 |  | 
 |         # tensor with zero dim | 
 |         x = torch.tensor(2 + 3j, dtype=dtype, device=device) | 
 |         res = torch.view_as_real(x) | 
 |         self.assertTrue(self.is_view_of(x, res)) | 
 |         self.assertEqual(res.shape, torch.Size([2])) | 
 |  | 
 |     @onlyNativeDeviceTypes | 
 |     @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)) | 
 |     def test_view_tensor_split(self, device, dtype): | 
 |         a = make_tensor((40, 30), dtype=dtype, device=device, low=-9, high=9) | 
 |         a_split_dim0 = a.tensor_split(7, 0) | 
 |         for a_split_dim0_tensor in a_split_dim0: | 
 |             self.assertTrue(self.is_view_of(a, a_split_dim0_tensor)) | 
 |         a_split_dim1 = a.tensor_split(7, 1) | 
 |         for a_split_dim1_tensor in a_split_dim1: | 
 |             self.assertTrue(self.is_view_of(a, a_split_dim1_tensor)) | 
 |  | 
 |     @onlyNativeDeviceTypes | 
 |     @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)) | 
 |     def test_view_tensor_hsplit(self, device, dtype): | 
 |         t = make_tensor((4, 4, 4), dtype=dtype, device=device, low=-9, high=9) | 
 |         t_hsplit = torch.hsplit(t, 2) | 
 |         for t_hsplit_tensor in t_hsplit: | 
 |             self.assertTrue(self.is_view_of(t, t_hsplit_tensor)) | 
 |         t[2, 2, 2] = 7 | 
 |         self.assertEqual(t_hsplit[1][2, 0, 2], t[2, 2, 2]) | 
 |  | 
 |     @onlyNativeDeviceTypes | 
 |     @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)) | 
 |     def test_view_tensor_vsplit(self, device, dtype): | 
 |         t = make_tensor((4, 4, 4), dtype=dtype, device=device, low=-9, high=9) | 
 |         t_vsplit = torch.vsplit(t, 2) | 
 |         for t_vsplit_tensor in t_vsplit: | 
 |             self.assertTrue(self.is_view_of(t, t_vsplit_tensor)) | 
 |         t[2, 2, 2] = 7 | 
 |         self.assertEqual(t_vsplit[1][0, 2, 2], t[2, 2, 2]) | 
 |  | 
 |     @onlyNativeDeviceTypes | 
 |     @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)) | 
 |     def test_view_tensor_dsplit(self, device, dtype): | 
 |         t = make_tensor((4, 4, 4), dtype=dtype, device=device, low=-9, high=9) | 
 |         t_dsplit = torch.dsplit(t, 2) | 
 |         for t_dsplit_tensor in t_dsplit: | 
 |             self.assertTrue(self.is_view_of(t, t_dsplit_tensor)) | 
 |         t[2, 2, 2] = 7 | 
 |         self.assertEqual(t_dsplit[1][2, 2, 0], t[2, 2, 2]) | 
 |  | 
 |     @onlyNativeDeviceTypes | 
 |     @dtypes(*all_types_and(torch.half, torch.bfloat16)) | 
 |     def test_imag_noncomplex(self, device, dtype): | 
 |         t = torch.ones((5, 5), dtype=dtype, device=device) | 
 |  | 
 |         with self.assertRaises(RuntimeError): | 
 |             torch.imag(t) | 
 |  | 
 |     @onlyNativeDeviceTypes | 
 |     @dtypes(*complex_types()) | 
 |     def test_real_imag_view(self, device, dtype): | 
 |         def compare_with_numpy(contiguous_input=True): | 
 |             t = torch.randn(3, 3, dtype=dtype, device=device) | 
 |             if not contiguous_input: | 
 |                 u = t.T | 
 |             else: | 
 |                 u = t | 
 |  | 
 |             re = u.real | 
 |             exp = torch.from_numpy(u.cpu().numpy().real).to(device=device) | 
 |             self.assertEqual(re, exp) | 
 |             # for the case of contiguous_input, t=u | 
 |             # for the case of non contiguous_input, the base still remains | 
 |             # t since we are performing a view operation to make the input non-contiguous | 
 |             self.assertTrue(self.is_view_of(t, re)) | 
 |  | 
 |             im = u.imag | 
 |             exp = torch.from_numpy(u.cpu().numpy().imag).to(device=device) | 
 |             self.assertEqual(im, exp) | 
 |             self.assertTrue(self.is_view_of(t, im)) | 
 |  | 
 |         compare_with_numpy() | 
 |         compare_with_numpy(contiguous_input=False) | 
 |  | 
 |         # ensure storage offset is being correctly set | 
 |         a = torch.randn(10, dtype=dtype) | 
 |         self.assertEqual(a[5:].real, a.real[5:]) | 
 |         self.assertEqual(a[5:].imag, a.imag[5:]) | 
 |  | 
 |     @onlyNativeDeviceTypes | 
 |     @dtypes(*complex_types()) | 
 |     def test_conj_imag_view(self, device, dtype) -> None: | 
 |         t = _make_tensor((4, 5,), dtype, device) | 
 |         t_numpy_conj = torch.from_numpy(t.cpu().numpy().conj()).to(device=device) | 
 |         v = t.conj() | 
 |         self.assertTrue(self.is_view_of(t, v)) | 
 |         self.assertEqual(v, t_numpy_conj) | 
 |  | 
 |         if (t.is_complex()): | 
 |             v_imag = v.imag | 
 |             self.assertTrue(self.is_view_of(t, v_imag)) | 
 |             self.assertEqual(v_imag, t_numpy_conj.imag) | 
 |             self.assertTrue(v_imag.is_neg()) | 
 |  | 
 |     @onlyNativeDeviceTypes | 
 |     def test_conj_view_with_shared_memory(self, device) -> None: | 
 |         a = _make_tensor((4, 5,), torch.cfloat, device) | 
 |         b = a.conj() | 
 |         c = a.conj() | 
 |  | 
 |         self.assertEqual(torch.add(a, b), a.add_(b)) | 
 |         self.assertEqual(torch.add(b, c), torch.add(b, c, out=a)) | 
 |         self.assertEqual(torch.add(b, c), b.add_(c)) | 
 |  | 
 |     @onlyNativeDeviceTypes | 
 |     @dtypes(*product(complex_types(), all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool))) | 
 |     @suppress_warnings | 
 |     def test_set_real_imag(self, device, dtypes): | 
 |         x = torch.randn(10, dtype=dtypes[0], device=device) | 
 |  | 
 |         new_real = _make_tensor((10,), dtypes[1], device) | 
 |         new_imag = _make_tensor((10,), dtypes[1], device) | 
 |  | 
 |         x.real = new_real | 
 |         x.imag = new_imag | 
 |  | 
 |         if dtypes[1].is_complex: | 
 |             self.assertEqual(x.real, new_real.real, exact_dtype=False) | 
 |             self.assertEqual(x.imag, new_imag.real, exact_dtype=False) | 
 |  | 
 |         else: | 
 |             self.assertEqual(x.real, new_real, exact_dtype=False) | 
 |             self.assertEqual(x.imag, new_imag, exact_dtype=False) | 
 |  | 
 |     def test_diagonal_view(self, device) -> None: | 
 |         t = torch.ones((5, 5), device=device) | 
 |         v = torch.diagonal(t) | 
 |         self.assertTrue(self.is_view_of(t, v)) | 
 |  | 
 |         v[0] = 0 | 
 |         self.assertEqual(t[0, 0], v[0]) | 
 |  | 
 |         t = torch.ones((3, 3, 3), device=device) | 
 |         v = torch.diagonal(t, offset=1, dim1=1, dim2=2) | 
 |         self.assertTrue(self.is_view_of(t, v)) | 
 |  | 
 |         v[0, 0] = 0 | 
 |         self.assertEqual(t[0, 0, 1], v[0, 0]) | 
 |  | 
 |     def test_select_view(self, device) -> None: | 
 |         t = torch.ones((5, 5), device=device) | 
 |         v = t.select(0, 2) | 
 |         self.assertTrue(self.is_view_of(t, v)) | 
 |  | 
 |         v[0] = 0 | 
 |         self.assertEqual(t[2, 0], v[0]) | 
 |  | 
 |     # Lazy hasn't implemented unbind yet. | 
 |     @onlyNativeDeviceTypes | 
 |     def test_unbind_view(self, device) -> None: | 
 |         t = torch.zeros((5, 5), device=device) | 
 |         tup = torch.unbind(t) | 
 |  | 
 |         for idx, v in enumerate(tup): | 
 |             self.assertTrue(self.is_view_of(t, v)) | 
 |  | 
 |             v[0] = idx + 1 | 
 |             self.assertEqual(t[idx, 0], v[0]) | 
 |  | 
 |     # TODO: opinfo this or move to unbind's test suite | 
 |     def test_unbind(self): | 
 |         stacked = torch.randn(3, 10, 10, requires_grad=True) | 
 |         x, y, z = stacked.unbind() | 
 |         grad = torch.randn(3, 10, 10) | 
 |         torch.autograd.backward([x, y, z], grad.unbind()) | 
 |         self.assertEqual(stacked.grad, grad) | 
 |         # check that it works with only one gradient provided (#9977) | 
 |         for i in range(3): | 
 |             stacked = torch.randn(3, 10, 10, requires_grad=True) | 
 |             outs = stacked.unbind() | 
 |             gi = grad.unbind()[i] | 
 |             g, = torch.autograd.grad(outs[i], stacked, gi) | 
 |             g_expected = torch.stack([gi if j == i else torch.zeros_like(gi) | 
 |                                       for j in range(3)], dim=0) | 
 |             self.assertEqual(g, g_expected) | 
 |         # Check with gradcheck | 
 |         stacked = torch.randn(3, 10, 10, dtype=torch.double, requires_grad=True) | 
 |         gradcheck(lambda x: x.unbind(), (stacked,), check_forward_ad=True) | 
 |  | 
 |     # TODO: Fix this test for LTC. There is an interaction with dynamic shapes here that is broken, | 
 |     # causing asserts to trigger. | 
 |     @onlyNativeDeviceTypes | 
 |     def test_expand_view(self, device) -> None: | 
 |         t = torch.ones((5, 1), device=device) | 
 |         v = t.expand(5, 5) | 
 |         self.assertTrue(self.is_view_of(t, v)) | 
 |  | 
 |         v[2, 2] = 0 | 
 |         self.assertEqual(t[2, 0], v[2, 2]) | 
 |  | 
 |     def test_expand_as_view(self, device): | 
 |         t = torch.ones((5, 1), device=device) | 
 |         e = torch.empty((5, 5), device=device) | 
 |         v = t.expand_as(e) | 
 |         self.assertTrue(self.is_view_of(t, v)) | 
 |  | 
 |         v[2, 2] = 0 | 
 |         self.assertEqual(t[2, 0], v[2, 2]) | 
 |  | 
 |     def test_narrow_view(self, device): | 
 |         t = torch.ones((5, 5), device=device) | 
 |         v = torch.narrow(t, 1, 2, 2) | 
 |         self.assertTrue(self.is_view_of(t, v)) | 
 |  | 
 |         v[0, 0] = 0 | 
 |         self.assertEqual(t[0, 2], v[0, 0]) | 
 |  | 
 |     def test_permute_view(self, device) -> None: | 
 |         t = torch.ones((5, 5), device=device) | 
 |         v = t.permute(1, 0) | 
 |         self.assertTrue(self.is_view_of(t, v)) | 
 |  | 
 |         v[0, 1] = 0 | 
 |         self.assertEqual(t[1, 0], v[0, 1]) | 
 |  | 
 |     def test_transpose_view(self, device): | 
 |         for fn in (torch.swapdims, torch.swapaxes, torch.transpose): | 
 |             t = torch.ones((5, 5), device=device) | 
 |             v = fn(t, 0, 1) | 
 |             self.assertTrue(self.is_view_of(t, v)) | 
 |  | 
 |             v[0, 1] = 0 | 
 |             self.assertEqual(t[1, 0], v[0, 1]) | 
 |  | 
 |     def test_transpose_inplace_view(self, device): | 
 |         t = torch.ones(5, 5, device=device) | 
 |         v = t.view_as(t) | 
 |         v = v.swapdims_(0, 1) | 
 |         self.assertTrue(self.is_view_of(t, v)) | 
 |         v[0, 1] = 0 | 
 |         self.assertEqual(t[1, 0], v[0, 1]) | 
 |  | 
 |         t = torch.ones(5, 5, device=device) | 
 |         v = t.view_as(t) | 
 |         v = v.swapaxes_(0, 1) | 
 |         self.assertTrue(self.is_view_of(t, v)) | 
 |         v[0, 1] = 0 | 
 |         self.assertEqual(t[1, 0], v[0, 1]) | 
 |  | 
 |         t = torch.ones(5, 5, device=device) | 
 |         v = t.view_as(t) | 
 |         v = v.transpose_(0, 1) | 
 |         self.assertTrue(self.is_view_of(t, v)) | 
 |         v[0, 1] = 0 | 
 |         self.assertEqual(t[1, 0], v[0, 1]) | 
 |  | 
 |     def test_t_view(self, device): | 
 |         t = torch.ones((5, 5), device=device) | 
 |         v = t.t() | 
 |         self.assertTrue(self.is_view_of(t, v)) | 
 |  | 
 |         v[0, 1] = 0 | 
 |         self.assertEqual(t[1, 0], v[0, 1]) | 
 |  | 
 |     def test_t_inplace_view(self, device): | 
 |         t = torch.ones(5, 5, device=device) | 
 |         v = t.view_as(t) | 
 |         v = v.t_() | 
 |         self.assertTrue(self.is_view_of(t, v)) | 
 |         v[0, 1] = 0 | 
 |         self.assertEqual(t[1, 0], v[0, 1]) | 
 |  | 
 |     def test_T_view(self, device): | 
 |         for op in ("T", "H", "mT", "mH"): | 
 |             t = torch.ones((5, 5), device=device) | 
 |             v = getattr(t, op) | 
 |             self.assertTrue(self.is_view_of(t, v)) | 
 |  | 
 |             v[0, 1] = 0 | 
 |             self.assertEqual(t[1, 0], v[0, 1]) | 
 |  | 
 |     def test_unfold_view(self, device): | 
 |         t = torch.ones(10, device=device) | 
 |         v = t.unfold(0, 3, 2) | 
 |         self.assertTrue(self.is_view_of(t, v)) | 
 |  | 
 |         v[1, 0] = 0 | 
 |         self.assertEqual(t[2], v[1, 0]) | 
 |  | 
 |     def test_squeeze_view(self, device): | 
 |         t = torch.ones(5, 1, 5, device=device) | 
 |         v = torch.squeeze(t) | 
 |         self.assertTrue(self.is_view_of(t, v)) | 
 |         v[0, 1] = 0 | 
 |         self.assertEqual(t, v._base) | 
 |  | 
 |     def test_squeeze_inplace_view(self, device): | 
 |         t = torch.ones(5, 5, device=device) | 
 |         v = t.view_as(t) | 
 |         v = v.squeeze_() | 
 |         self.assertTrue(self.is_view_of(t, v)) | 
 |         v[0, 1] = 0 | 
 |         self.assertEqual(t, v._base) | 
 |  | 
 |     def test_unsqueeze_view(self, device): | 
 |         t = torch.ones(5, 5, device=device) | 
 |         v = torch.unsqueeze(t, 1) | 
 |         self.assertTrue(self.is_view_of(t, v)) | 
 |  | 
 |         v[0, 0, 1] = 0 | 
 |         self.assertEqual(t[0, 1], v[0, 0, 1]) | 
 |  | 
 |     def test_unsqueeze_inplace_view(self, device): | 
 |         t = torch.ones(5, 5, device=device) | 
 |         v = t.view_as(t) | 
 |         v = v.unsqueeze_(1) | 
 |         self.assertTrue(self.is_view_of(t, v)) | 
 |         v[0, 0, 1] = 0 | 
 |         self.assertEqual(t[0, 1], v[0, 0, 1]) | 
 |  | 
 |     def test_as_strided_view(self, device): | 
 |         t = torch.ones(5, 5, device=device) | 
 |         v = torch.as_strided(t, (25,), (1,)) | 
 |         self.assertTrue(self.is_view_of(t, v)) | 
 |  | 
 |         v[6] = 0 | 
 |         self.assertEqual(t[1, 1], v[6]) | 
 |  | 
 |     def test_as_strided_inplace_view(self, device): | 
 |         t = torch.ones(5, 5, device=device) | 
 |         v = t.view_as(t) | 
 |         v = v.as_strided_((25,), (1,)) | 
 |         self.assertTrue(self.is_view_of(t, v)) | 
 |         v[6] = 0 | 
 |         self.assertEqual(t[1, 1], v[6]) | 
 |  | 
 |     def test_as_strided_gradients(self): | 
 |         def test(x, prepro_fn, size, strides, offset=None): | 
 |             x = x.to(torch.double).detach().requires_grad_() | 
 |  | 
 |             # Check that forward will **not** resize storage because it may | 
 |             # cause NaN in output and fail numerical Jacobian check consequently | 
 |             with torch.no_grad(): | 
 |                 y = prepro_fn(x) if prepro_fn is not None else x | 
 |                 max_offset = sum((si - 1) * st for si, st in zip(size, strides)) | 
 |                 max_offset += offset if offset is not None else y.storage_offset() | 
 |                 assert max_offset < len(y.storage()), "test case resizes storage" | 
 |  | 
 |             def closure(x): | 
 |                 if prepro_fn is not None: | 
 |                     x = prepro_fn(x) | 
 |                 return x.as_strided(size, strides, offset) | 
 |  | 
 |             gradcheck(closure, [x], check_forward_ad=True) | 
 |             gradgradcheck(closure, [x]) | 
 |  | 
 |         # test | 
 |         test(torch.arange(0, 25), lambda x: x.view(5, 5), [3, 3], [6, 2], 2) | 
 |  | 
 |         # test crazy stride at dim with size 1 case | 
 |         test(torch.randn(12), None, [1, 2, 1, 5], [0, 5, 100, 1], 2) | 
 |  | 
 |         # test expand case | 
 |         test(torch.randn(5), None, [3, 3, 3], [0, 1, 0], 2) | 
 |         test(torch.randn(5), None, [3, 3, 3], [0, 0, 0], 4) | 
 |         test(torch.randn(5), lambda x: x.expand(5, 5), [5, 5], [0, 1], 0) | 
 |  | 
 |         # test non-expand overlapping case | 
 |         test(torch.randn(35), None, [6, 6], [5, 1], 2) | 
 |         test(torch.randn(15), None, [3, 2], [3, 6], 2) | 
 |  | 
 |         # test transpose case | 
 |         test(torch.randn(3, 4), None, [4, 3], [1, 4]) | 
 |  | 
 |         # test "getting things outside the input" case | 
 |         x = torch.randn(6, 2) | 
 |         test(x[3:], None, [3, 2], [2, 1], 0)  # should be all zeros | 
 |         self.assertEqual(x[3:].as_strided([3, 2], [2, 1], 0), x[:3]) | 
 |  | 
 |         # test select on expanded input case | 
 |         test(torch.randn(2, 3), lambda x: x.expand(10, 2, 3), [2, 3], [3, 1], 0) | 
 |  | 
 |     def test_view_view(self, device): | 
 |         t = torch.ones(5, 5, device=device) | 
 |         v = t.view(25) | 
 |         self.assertTrue(self.is_view_of(t, v)) | 
 |  | 
 |         v[6] = 0 | 
 |         self.assertEqual(t[1, 1], v[6]) | 
 |  | 
 |     def test_view_as_view(self, device): | 
 |         t = torch.ones(5, 5, device=device) | 
 |         e = torch.empty((25,)) | 
 |         v = t.view_as(e) | 
 |         self.assertTrue(self.is_view_of(t, v)) | 
 |  | 
 |         v[6] = 0 | 
 |         self.assertEqual(t[1, 1], v[6]) | 
 |  | 
 |     def test_contiguous_self(self, device): | 
 |         t = torch.ones(5, 5, device=device) | 
 |         s = t.contiguous() | 
 |         self.assertTrue(s is t) | 
 |  | 
 |     @skipMeta | 
 |     # self.is_view_of reports false positives for lazy | 
 |     @onlyNativeDeviceTypes | 
 |     def test_contiguous_nonview(self, device): | 
 |         t = torch.ones(5, 5, device=device) | 
 |         nv = t.t().contiguous() | 
 |         self.assertTrue(not self.is_view_of(t, nv)) | 
 |  | 
 |         nv[0, 0] = 0 | 
 |         self.assertNotEqual(t[0, 0], nv[0, 0]) | 
 |  | 
 |     def test_reshape_view(self, device): | 
 |         t = torch.ones(5, 5, device=device) | 
 |         v = torch.reshape(t, (25,)) | 
 |         self.assertTrue(self.is_view_of(t, v)) | 
 |  | 
 |         v[6] = 0 | 
 |         self.assertEqual(t[1, 1], v[6]) | 
 |  | 
 |     def test_reshape_as_view(self, device): | 
 |         t = torch.ones(5, 5, device=device) | 
 |         e = torch.empty((25,), device=device) | 
 |         v = t.reshape_as(e) | 
 |         self.assertTrue(self.is_view_of(t, v)) | 
 |  | 
 |         v[6] = 0 | 
 |         self.assertEqual(t[1, 1], v[6]) | 
 |  | 
 |     @skipMeta | 
 |     # self.is_view_of reports false positives for lazy | 
 |     @onlyNativeDeviceTypes | 
 |     def test_reshape_nonview(self, device): | 
 |         t = torch.ones(5, 5, device=device) | 
 |         nv = torch.reshape(t.t(), (25,)) | 
 |         self.assertTrue(not self.is_view_of(t, nv)) | 
 |  | 
 |         nv[6] = 0 | 
 |         self.assertNotEqual(t[1, 1], nv[6]) | 
 |  | 
 |     # This test use as_strided to construct a tensor with overlapping memory, | 
 |     # which is not handled by the functionalization pass. | 
 |     @onlyNativeDeviceTypes | 
 |     def test_flatten_view(self, device): | 
 |         def test_writes_propagate(t, v): | 
 |             idx_t = (0,) * t.ndim | 
 |             idx_v = (0,) * v.ndim | 
 |             v[idx_v] = 0 | 
 |             self.assertEqual(t[idx_t], v[idx_v]) | 
 |  | 
 |         t = torch.ones(1, 2, 3, 4, device=device) | 
 |         v = t.flatten() | 
 |         self.assertTrue(self.is_view_of(t, v)) | 
 |         test_writes_propagate(t, v) | 
 |  | 
 |         # zero-dimensional tensor | 
 |         t = torch.tensor(1, device=device) | 
 |         v = t.flatten() | 
 |         test_writes_propagate(t, v) | 
 |         self.assertTrue(self.is_view_of(t, v)) | 
 |  | 
 |         t = torch.ones(1, 2, 3, 4, device=device).transpose(2, 3) | 
 |         v = t.flatten(0, 1) | 
 |         test_writes_propagate(t, v) | 
 |         self.assertTrue(self.is_view_of_same_base(t, v)) | 
 |  | 
 |         # stride[i] = stride[i + 1] * size[i + 1] is satisfied for 3 groups: | 
 |         t = torch.ones(720, device=device) \ | 
 |             .as_strided((2, 3, 2, 3, 5, 4), (6, 2, 15, 5, 1, 0)) | 
 |         #               [--1--|---2---|-3-] [--1--|----2---|-3-] | 
 |         v1 = t.flatten(0, 1) | 
 |         v2 = v1.flatten(1, 3) | 
 |         v3 = v2.flatten(2, 2) | 
 |         test_writes_propagate(t, v1) | 
 |         self.assertTrue(self.is_view_of_same_base(t, v1)) | 
 |         test_writes_propagate(t, v2) | 
 |         self.assertTrue(self.is_view_of_same_base(t, v2)) | 
 |         test_writes_propagate(t, v3) | 
 |         self.assertTrue(self.is_view_of_same_base(t, v3)) | 
 |  | 
 |     @onlyNativeDeviceTypes | 
 |     def test_flatten_nonview(self, device): | 
 |         def assert_is_nonview(t, nv): | 
 |             idx_t = (0,) * t.ndim | 
 |             idx_nv = (0,) * nv.ndim | 
 |             self.assertTrue(not nv._is_view()) | 
 |             nv[idx_nv] = 0 | 
 |             if device != "meta": | 
 |                 self.assertNotEqual(t[idx_t], nv[idx_nv]) | 
 |         t = torch.ones(2, 3, 2, 3, device=device).transpose(2, 3) | 
 |         nv = t.flatten(1, 3) | 
 |         assert_is_nonview(t, nv) | 
 |  | 
 |         t = torch.ones(2, 2, device=device).T | 
 |         nv = t.flatten() | 
 |         assert_is_nonview(t, nv) | 
 |  | 
 |         # flatten returns the original object if start_dim=end_dim | 
 |         t = t = torch.ones(2, 2, device=device) | 
 |         nv = t.flatten(1, 1) | 
 |         self.assertTrue(t is nv) | 
 |  | 
 |     def test_basic_indexing_slice_view(self, device): | 
 |         t = torch.ones(5, 5, device=device) | 
 |         v = t[:2, :3] | 
 |         self.assertTrue(self.is_view_of(t, v)) | 
 |  | 
 |         v[0, 0] = 0 | 
 |         self.assertEqual(t[0, 0], v[0, 0]) | 
 |  | 
 |     def test_basic_indexing_ellipses_view(self, device): | 
 |         t = torch.ones(5, 5, device=device) | 
 |         v = t[..., :2] | 
 |         self.assertTrue(self.is_view_of(t, v)) | 
 |  | 
 |         v[0, 0] = 0 | 
 |         self.assertEqual(t[0, 0], v[0, 0]) | 
 |  | 
 |     def test_basic_indexing_newaxis_view(self, device): | 
 |         t = torch.ones(5, 5, device=device) | 
 |         v = t[None, :2, 3] | 
 |         self.assertTrue(self.is_view_of(t, v)) | 
 |  | 
 |         v[0, 0] = 0 | 
 |         self.assertEqual(t[0, 3], v[0, 0]) | 
 |  | 
 |     def test_advanced_indexing_nonview(self, device): | 
 |         t = torch.ones(3, 3, device=device) | 
 |         rows = torch.tensor([[0, 0], [2, 2]], device=device) | 
 |         cols = torch.tensor([[0, 1], [2, 2]], device=device) | 
 |         nv = t[rows, cols] | 
 |         self.assertTrue(not self.is_view_of(t, nv)) | 
 |  | 
 |         nv[1, 1] = 0 | 
 |         self.assertNotEqual(t[2, 2], nv[1, 1]) | 
 |  | 
 |     @unittest.skipIf(IS_FBCODE, "TorchScript backend not yet supported in FBCODE/OVRSOURCE builds") | 
 |     def test_advanced_indexing_assignment(self, device): | 
 |         t = torch.ones(3, 3, device=device) | 
 |         rows = torch.tensor([[0, 0], [2, 2]], device=device) | 
 |         cols = torch.tensor([[0, 1], [2, 2]], device=device) | 
 |         t[rows, cols] = 0 | 
 |         self.assertEqual(t[2, 2], 0) | 
 |  | 
 |     @unittest.skip("See https://github.com/pytorch/pytorch/pull/32720") | 
 |     def test_chunk_view(self, device): | 
 |         t = torch.zeros(3, 3, device=device) | 
 |         l = torch.chunk(t, 3) | 
 |  | 
 |         for idx, v in enumerate(l): | 
 |             self.assertTrue(self.is_view_of(t, v)) | 
 |  | 
 |             v[0, 0] = idx + 1 | 
 |             self.assertEqual(t[idx, 0], v[0, 0]) | 
 |  | 
 |     @unittest.skip("See https://github.com/pytorch/pytorch/pull/32720") | 
 |     def test_split_view(self, device): | 
 |         t = torch.zeros(3, 3, device=device) | 
 |         l = torch.split(t, [1, 1, 1]) | 
 |  | 
 |         for idx, v in enumerate(l): | 
 |             self.assertTrue(self.is_view_of(t, v)) | 
 |  | 
 |             v[0, 0] = idx + 1 | 
 |             self.assertEqual(t[idx, 0], v[0, 0]) | 
 |  | 
 |     def test_movedim_view(self, device): | 
 |         def run_test(device, op): | 
 |             t = torch.zeros(3, 3, device=device) | 
 |             out = op(t) | 
 |  | 
 |             self.assertTrue(self.is_view_of(t, out)) | 
 |  | 
 |             # Randomly change values in output | 
 |             # and verify that original is changed | 
 |             # as well. | 
 |             for _ in range(3): | 
 |                 idx_1, idx_2 = random.randint(0, 2), random.randint(0, 2) | 
 |                 out[idx_1, idx_2] = random.random() | 
 |                 self.assertEqual(t[idx_2, idx_1], out[idx_1, idx_2]) | 
 |  | 
 |         for fn in [torch.movedim, torch.moveaxis]: | 
 |             op = partial(fn, source=(0, 1), destination=(1, 0)) | 
 |             run_test(device, op) | 
 |  | 
 |             op = partial(fn, source=0, destination=1) | 
 |             run_test(device, op) | 
 |  | 
 |     # Testing that the generated view_copy kernel and its derivative are implemented correctly | 
 |     def test_view_copy(self, device): | 
 |         a = torch.randn(4, device=device, requires_grad=True) | 
 |         a_ref = a.clone().detach().requires_grad_() | 
 |         a_view = a_ref.view(2, 2) | 
 |         a_view_copy = torch.view_copy(a, (2, 2)) | 
 |  | 
 |         # view_copy ops don't preserve view relationship | 
 |         self.assertTrue(self.is_view_of(a_ref, a_view)) | 
 |         self.assertFalse(self.is_view_of(a, a_view_copy)) | 
 |  | 
 |         a_view_copy.sum().backward() | 
 |         a_view.sum().backward() | 
 |  | 
 |         # forward and backward give the same shape + result | 
 |         self.assertEqual(a_view_copy, a_view) | 
 |         self.assertEqual(a.grad, a_ref.grad) | 
 |  | 
 |     # Testing that the output of a view_copy kernel (by default) is contiguous. | 
 |     def test_view_copy_output_contiguous(self, device): | 
 |         a = torch.randn(4, 4, 4, 4, device=device).to(memory_format=torch.channels_last) | 
 |         b = torch.ops.aten.slice_copy(a, 0, 0, 2) | 
 |         self.assertTrue(b.is_contiguous()) | 
 |  | 
 |     def test_view_copy_out(self, device): | 
 |         a = torch.randn(2, 2, device=device) | 
 |         out = torch.empty(2, device=device) | 
 |  | 
 |         torch.diagonal_copy(a, out=out) | 
 |         expected = torch.diagonal_copy(a) | 
 |  | 
 |         self.assertEqual(expected, out) | 
 |  | 
 |         a = torch.randn(4, device=device) | 
 |         out1 = torch.empty(2, device=device) | 
 |         out2 = torch.empty(2, device=device) | 
 |  | 
 |         torch.split_copy(a, 2, out=(out1, out2)) | 
 |         expected1, expected2 = torch.split_copy(a, 2) | 
 |  | 
 |         self.assertEqual(expected1, out1) | 
 |         self.assertEqual(expected2, out2) | 
 |  | 
 | class TestOldViewOps(TestCase): | 
 |     def test_ravel(self, device): | 
 |  | 
 |         def _test_ravel(tensors, size, nc=False): | 
 |             for src in tensors: | 
 |                 # Continuous Tensor -> View | 
 |                 flat = src.ravel() | 
 |                 self.assertEqual(flat.shape, torch.Size([size])) | 
 |                 self.assertEqual(src.view(-1), flat) | 
 |                 self.assertIs(flat._base, src) | 
 |                 self.assertTrue(flat.is_contiguous()) | 
 |  | 
 |                 # Non-continuous Tensor -> Copy | 
 |                 if nc: | 
 |                     nc_src = src.t() | 
 |                     nc_flat = nc_src.ravel() | 
 |                     self.assertEqual(nc_flat.shape, torch.Size([size])) | 
 |                     self.assertEqual(nc_src.contiguous().view(-1), nc_flat) | 
 |                     self.assertIsNot(nc_flat._base, src) | 
 |                     self.assertTrue(nc_flat.is_contiguous()) | 
 |  | 
 |         # Test that flatten returns 1-dim tensor when given a 0-dim tensor | 
 |         zero_dim_tensor = torch.tensor(123, device=device) | 
 |         flat0 = zero_dim_tensor.ravel() | 
 |         one_dim_tensor = torch.tensor([123], device=device) | 
 |         flat1 = zero_dim_tensor.ravel() | 
 |         nc_ones_tensor = torch.ones(10, device=device)[::2] | 
 |         flat2 = nc_ones_tensor.ravel() | 
 |  | 
 |         self.assertEqual(zero_dim_tensor.shape, torch.Size([])) | 
 |         self.assertEqual(flat0.shape, torch.Size([1])) | 
 |         self.assertEqual(one_dim_tensor.shape, torch.Size([1])) | 
 |         self.assertEqual(flat1.shape, torch.Size([1])) | 
 |         self.assertEqual(nc_ones_tensor.shape, torch.Size([5])) | 
 |         self.assertEqual(flat2.shape, torch.Size([5])) | 
 |         self.assertEqual(flat0, one_dim_tensor) | 
 |         self.assertEqual(flat0, flat1) | 
 |         self.assertEqual(flat0.shape, flat1.shape) | 
 |         self.assertTrue(flat0.is_contiguous()) | 
 |         self.assertTrue(flat1.is_contiguous()) | 
 |         self.assertTrue(flat2.is_contiguous()) | 
 |  | 
 |         # Test both float tensor and quantized tensor | 
 |         tensors = [torch.randn(5, 5, 5, 5, device=device), | 
 |                    torch._empty_affine_quantized([5, 5, 5, 5], | 
 |                                                  scale=2, | 
 |                                                  zero_point=3, | 
 |                                                  dtype=torch.quint8, | 
 |                                                  device=device)] | 
 |         _test_ravel(tensors, 625) | 
 |  | 
 |         tensors = [torch.randn(0, 2, 3, device=device), | 
 |                    torch.randn(3, 0, 2, device=device), | 
 |                    torch._empty_affine_quantized([0, 2, 3], | 
 |                                                  scale=2, | 
 |                                                  zero_point=3, | 
 |                                                  dtype=torch.quint8, | 
 |                                                  device=device), | 
 |                    torch._empty_affine_quantized([3, 0, 2], | 
 |                                                  scale=2, | 
 |                                                  zero_point=3, | 
 |                                                  dtype=torch.quint8, | 
 |                                                  device=device)] | 
 |         _test_ravel(tensors, 0) | 
 |  | 
 |         tensors = [torch.randn(5, 5, device=device), | 
 |                    torch._empty_affine_quantized([5, 5], | 
 |                                                  scale=2, | 
 |                                                  zero_point=3, | 
 |                                                  dtype=torch.quint8, | 
 |                                                  device=device)] | 
 |         _test_ravel(tensors, 25, True) | 
 |  | 
 |     # TODO: this should be refactored into the view ops test suite | 
 |     def test_empty_reshape(self, device): | 
 |         x = torch.randn(0, 6, device=device) | 
 |         self.assertEqual((1, 0, 6, 1, 1), x.reshape(1, 0, 6, 1, 1).shape) | 
 |         # should be viewable -- i.e. data_ptr is the same. | 
 |         self.assertEqual(x.data_ptr(), x.reshape(1, 0, 6, 1, 1).data_ptr()) | 
 |  | 
 |         # match NumPy semantics -- don't infer the size of dimension with a degree of freedom | 
 |         self.assertRaises(RuntimeError, lambda: x.reshape(0, -1)) | 
 |  | 
 |     @skipIfTorchDynamo("TorchDynamo fails with unknown reason") | 
 |     def test_expand(self, device): | 
 |         tensor = torch.rand(1, 8, 1, device=device) | 
 |         tensor2 = torch.rand(5, device=device) | 
 |         template = torch.rand(4, 8, 5, device=device) | 
 |         target = template.size() | 
 |         self.assertEqual(tensor.expand_as(template).size(), target) | 
 |         self.assertEqual(tensor.expand(4, 8, 5).size(), target) | 
 |         self.assertEqual(tensor.expand(target).size(), target) | 
 |         self.assertEqual(tensor2.expand_as(template).size(), target) | 
 |         self.assertEqual(tensor2.expand(4, 8, 5).size(), target) | 
 |         self.assertEqual(tensor2.expand(target).size(), target) | 
 |  | 
 |         # test double expand | 
 |         self.assertEqual(tensor2.expand(1, 5).expand(2, 2, 5), tensor2.repeat(2, 2, 1)) | 
 |  | 
 |         # test non-contiguous | 
 |         noncontig = torch.randn(5, 2, 1, 3, device=device)[:, 0] | 
 |         self.assertFalse(noncontig.is_contiguous()) | 
 |         self.assertEqual(noncontig.expand(2, 5, 4, 3), noncontig.contiguous().repeat(2, 1, 4, 1)) | 
 |  | 
 |         # make sure it's compatible with unsqueeze | 
 |         expanded = tensor2.expand(1, 1, 5) | 
 |         unsqueezed = tensor2.unsqueeze(0).unsqueeze(1) | 
 |         self.assertEqual(expanded, unsqueezed) | 
 |         self.assertEqual(expanded.stride(), unsqueezed.stride()) | 
 |  | 
 |         # test -1 as target size | 
 |         self.assertEqual(tensor.expand(4, -1, 5), tensor.expand(4, 8, 5)) | 
 |         self.assertRaises(RuntimeError, lambda: tensor2.expand(-1, -1)) | 
 |  | 
 |         # test expanding empty to empty | 
 |         self.assertEqual(torch.zeros(0, device=device).expand((0,)), torch.zeros(0, device=device)) | 
 |  | 
 |     # TODO: this should be refactored into the view ops test suite | 
 |     def test_view_empty(self, device): | 
 |         x = torch.randn(0, 6, device=device) | 
 |         self.assertEqual((1, 0, 6, 1, 1), x.view(1, 0, 6, 1, 1).shape) | 
 |  | 
 |     # TODO: this should be refactored into the view ops test suite | 
 |     @onlyNativeDeviceTypes | 
 |     def test_reshape(self, device): | 
 |         x = torch.randn(3, 3, device=device) | 
 |         self.assertEqual(x.data_ptr(), x.reshape(-1).data_ptr()) | 
 |         self.assertEqual(x.data_ptr(), x.reshape(1, 9, 1).data_ptr()) | 
 |         self.assertEqual(torch.reshape(x, (9,)), x.reshape(9)) | 
 |         self.assertRaises(RuntimeError, lambda: x.reshape(-1, -1)) | 
 |  | 
 |         y = torch.randn(4, 4, 4, device=device)[:, 0, :] | 
 |         # .data_ptr() on meta tensors is always 0 so they are equal regardless of the reshape | 
 |         if device != "meta": | 
 |             self.assertNotEqual(y.data_ptr(), y.reshape(-1).data_ptr()) | 
 |         self.assertEqual(y.contiguous().view(-1), y.reshape(-1)) | 
 |         self.assertEqual(y.reshape(2, 2, 4).data_ptr(), y.data_ptr()) | 
 |  | 
 |         s = torch.randn((), device=device) | 
 |         self.assertEqual(s.data_ptr(), s.reshape(()).data_ptr()) | 
 |         self.assertEqual(s.reshape(-1).shape, (1,)) | 
 |         self.assertRaises(RuntimeError, lambda: s.reshape(2)) | 
 |  | 
 |         empty = torch.tensor([], device=device) | 
 |         self.assertEqual(empty, empty.reshape(-1)) | 
 |         self.assertEqual(empty, empty.reshape([0])) | 
 |         # TODO: fix these once we have multi-dimensional empty tensors | 
 |         self.assertEqual(empty.reshape([0, 1]).shape, (0, 1)) | 
 |         self.assertEqual(empty.reshape([1, -1]).shape, (1, 0)) | 
 |         self.assertRaises(RuntimeError, lambda: empty.reshape(1)) | 
 |  | 
 |         x = torch.randn(3, 3, device=device) | 
 |         self.assertEqual(x.data_ptr(), x.reshape_as(torch.rand(9)).data_ptr()) | 
 |         self.assertEqual(x.data_ptr(), x.reshape_as(torch.rand(1, 9, 1)).data_ptr()) | 
 |         self.assertRaises(RuntimeError, lambda: x.reshape_as(torch.rand(10, device=device))) | 
 |  | 
 |     def test_flatten(self, device): | 
 |         # Test that flatten returns 1-dim tensor when given a 0-dim tensor | 
 |         zero_dim_tensor = torch.tensor(123, device=device) | 
 |         flat0 = zero_dim_tensor.flatten() | 
 |         one_dim_tensor = torch.tensor([123], device=device) | 
 |         flat1 = zero_dim_tensor.flatten() | 
 |  | 
 |         self.assertEqual(zero_dim_tensor.shape, torch.Size([])) | 
 |         self.assertEqual(flat0.shape, torch.Size([1])) | 
 |         self.assertEqual(one_dim_tensor.shape, torch.Size([1])) | 
 |         self.assertEqual(flat1.shape, torch.Size([1])) | 
 |         self.assertEqual(flat0, one_dim_tensor) | 
 |         self.assertEqual(flat0, flat1) | 
 |         self.assertEqual(flat0.shape, flat1.shape) | 
 |  | 
 |         # Test both float tensor and quantized tensor | 
 |         tensors = [torch.randn(5, 5, 5, 5, device=device), | 
 |                    torch._empty_affine_quantized([5, 5, 5, 5], | 
 |                                                  scale=2, | 
 |                                                  zero_point=3, | 
 |                                                  dtype=torch.quint8, | 
 |                                                  device=device)] | 
 |         for src in tensors: | 
 |             flat = src.flatten(0, -1) | 
 |             self.assertEqual(flat.shape, torch.Size([625])) | 
 |             self.assertEqual(src.view(-1), flat.view(-1)) | 
 |  | 
 |             flat = src.flatten(0, 2) | 
 |             self.assertEqual(flat.shape, torch.Size([125, 5])) | 
 |             self.assertEqual(src.view(-1), flat.view(-1)) | 
 |  | 
 |             flat = src.flatten(0, 1) | 
 |             self.assertEqual(flat.shape, torch.Size([25, 5, 5])) | 
 |             self.assertEqual(src.view(-1), flat.view(-1)) | 
 |  | 
 |             flat = src.flatten(1, 2) | 
 |             self.assertEqual(flat.shape, torch.Size([5, 25, 5])) | 
 |             self.assertEqual(src.view(-1), flat.view(-1)) | 
 |  | 
 |             flat = src.flatten(2, 3) | 
 |             self.assertEqual(flat.shape, torch.Size([5, 5, 25])) | 
 |             self.assertEqual(src.view(-1), flat.view(-1)) | 
 |  | 
 |             flat = src.flatten(-2, -1) | 
 |             self.assertEqual(flat.shape, torch.Size([5, 5, 25])) | 
 |             self.assertEqual(src.view(-1), flat.view(-1)) | 
 |  | 
 |             flat = src.flatten(2, 2) | 
 |             self.assertEqual(flat, src) | 
 |  | 
 |             # out of bounds index | 
 |             with self.assertRaisesRegex(IndexError, 'Dimension out of range'): | 
 |                 src.flatten(5, 10) | 
 |  | 
 |             # invalid start and end | 
 |             with self.assertRaisesRegex(RuntimeError, 'start_dim cannot come after end_dim'): | 
 |                 src.flatten(2, 0) | 
 |  | 
 |     # TODO: update to work on CUDA, too | 
 |     @onlyCPU | 
 |     def test_narrow(self, device): | 
 |         x = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]]) | 
 |         self.assertEqual(x.narrow(0, 0, 1), torch.tensor([[0, 1, 2]])) | 
 |         self.assertEqual(x.narrow(0, 0, 2), torch.tensor([[0, 1, 2], [3, 4, 5]])) | 
 |         self.assertEqual(x.narrow(0, 1, 1), torch.tensor([[3, 4, 5]])) | 
 |         self.assertEqual(x.narrow(0, -1, 1), torch.tensor([[6, 7, 8]])) | 
 |         self.assertEqual(x.narrow(0, -2, 2), torch.tensor([[3, 4, 5], [6, 7, 8]])) | 
 |         self.assertEqual(x.narrow(0, -3, 3), torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]])) | 
 |         self.assertEqual(x.narrow(-1, -1, 1), torch.tensor([[2], [5], [8]])) | 
 |         self.assertEqual(x.narrow(-2, -1, 1), torch.tensor([[6, 7, 8]])) | 
 |  | 
 |     # TODO: update to work on CUDA, too | 
 |     @onlyCPU | 
 |     def test_narrow_tensor(self, device): | 
 |         x = torch.tensor([[0, 1, 2], [3, 4, 5], [6, 7, 8]]) | 
 |         self.assertEqual(x.narrow(0, torch.tensor(0), 1), torch.tensor([[0, 1, 2]])) | 
 |         with self.assertRaises(Exception): | 
 |             x.narrow(0, torch.tensor(0.), 1) | 
 |         with self.assertRaises(Exception): | 
 |             x.narrow(0, torch.tensor([0]), 1) | 
 |         with self.assertRaises(Exception): | 
 |             x.narrow(0, torch.tensor([0, 1]), 1) | 
 |  | 
 |     # TODO: make work on CUDA, too | 
 |     @onlyCPU | 
 |     def test_t(self, device): | 
 |         # Test 0D tensors | 
 |         x = torch.randn(()) | 
 |         self.assertEqual(x, x.t()) | 
 |         x = x.to_sparse() | 
 |         self.assertEqual(x, x.t()) | 
 |  | 
 |         # Test 1D tensors | 
 |         x = torch.arange(4) | 
 |         self.assertEqual(x, x.t()) | 
 |         x = x.to_sparse() | 
 |         self.assertEqual(x, x.t()) | 
 |  | 
 |         # Test 2D tensors | 
 |         x = torch.rand((2, 2)) | 
 |         self.assertEqual(x.t(), x.transpose(0, 1)) | 
 |         x = x.to_sparse() | 
 |         self.assertEqual(x.t(), x.transpose(0, 1)) | 
 |  | 
 |         # Test 3D tensor | 
 |         x = torch.rand((2, 2, 2)) | 
 |         with self.assertRaisesRegex(RuntimeError, 'expects a tensor with <= 2 dimensions, but self is 3D'): | 
 |             x.t() | 
 |         x = x.to_sparse() | 
 |         with self.assertRaisesRegex(RuntimeError, 'expects a tensor with <= 2 sparse and 0 dense dimensions'): | 
 |             x.t() | 
 |  | 
 |     @onlyCPU | 
 |     def test_split(self, device): | 
 |         tensor = torch.rand(7, 4) | 
 |         split_size = 3 | 
 |         dim = 0 | 
 |         target_sizes = ([3, 4], [3, 4], [1, 4]) | 
 |         splits = tensor.split(split_size, dim) | 
 |         start = 0 | 
 |         for target_size, split in zip(target_sizes, splits): | 
 |             self.assertEqual(split.size(), target_size) | 
 |             self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, atol=0, rtol=0) | 
 |             start = start + target_size[dim] | 
 |  | 
 |         # Variable sections split | 
 |         tensor = torch.randn(20, 10) | 
 |         dim = 0 | 
 |         split_sizes = [5, 5, 10] | 
 |         target_sizes = ([[5, 10], [5, 10], [10, 10]]) | 
 |         splits = tensor.split(split_sizes, dim) | 
 |         start = 0 | 
 |         for target_size, split in zip(target_sizes, splits): | 
 |             self.assertEqual(split.size(), target_size) | 
 |             self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, atol=0, rtol=0) | 
 |             start = start + target_size[dim] | 
 |  | 
 |         split_sizes = [2, 2, 6] | 
 |         target_sizes = ([20, 2], [20, 2], [20, 6]) | 
 |         dim = 1 | 
 |         splits = tensor.split(split_sizes, dim) | 
 |         start = 0 | 
 |         for target_size, split in zip(target_sizes, splits): | 
 |             self.assertEqual(split.size(), target_size) | 
 |             self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, atol=0, rtol=0) | 
 |             start = start + target_size[dim] | 
 |  | 
 |     @onlyCPU | 
 |     def test_chunk(self, device): | 
 |         tensor = torch.rand(4, 7) | 
 |         num_chunks = 3 | 
 |         dim = 1 | 
 |         target_sizes = ([4, 3], [4, 3], [4, 1]) | 
 |         splits = tensor.chunk(num_chunks, dim) | 
 |         start = 0 | 
 |         for target_size, split in zip(target_sizes, splits): | 
 |             self.assertEqual(split.size(), target_size) | 
 |             self.assertEqual(tensor.narrow(dim, start, target_size[dim]), split, | 
 |                              atol=0, rtol=0) | 
 |             start = start + target_size[dim] | 
 |  | 
 |         # Invalid chunk sizes | 
 |         error_regex = 'chunk expects.*greater than 0' | 
 |         with self.assertRaisesRegex(RuntimeError, error_regex): | 
 |             tensor.chunk(0) | 
 |         with self.assertRaisesRegex(RuntimeError, error_regex): | 
 |             tensor.chunk(-2) | 
 |  | 
 |     # TODO: make work on CUDA, too | 
 |     @skipIfTorchDynamo("TorchDynamo fails with unknown reason") | 
 |     @onlyCPU | 
 |     def test_unsqueeze(self, device) -> None: | 
 |         x = torch.randn(2, 3, 4) | 
 |         y = x.unsqueeze(1) | 
 |         self.assertEqual(y, x.view(2, 1, 3, 4)) | 
 |         y = x.clone().unsqueeze_(2) | 
 |         self.assertEqual(y, x.view(2, 3, 1, 4)) | 
 |  | 
 |         x = x[:, 1] | 
 |         self.assertFalse(x.is_contiguous()) | 
 |         y = x.unsqueeze(1) | 
 |         self.assertEqual(y, x.contiguous().view(2, 1, 4)) | 
 |         y = x.clone().unsqueeze_(2) | 
 |         self.assertEqual(y, x.contiguous().view(2, 4, 1)) | 
 |  | 
 |     # unit test for special case transposed copy (see ATen/native/Copy.cpp for details) | 
 |     def test_big_transpose(self, device): | 
 |         t = torch.rand(456, 789, device=device) | 
 |         t1 = t.t().contiguous() | 
 |         t2 = torch.from_numpy(t.cpu().numpy().transpose()) | 
 |         self.assertEqual(t1, t2) | 
 |  | 
 |     def test_T(self, device): | 
 |         a = torch.randn(2, 3, 4, device=device) | 
 |         t1 = a.T | 
 |         t2 = a.permute(2, 1, 0) | 
 |         self.assertEqual(t2, t1) | 
 |         b = torch.randn(10, device=device) | 
 |         self.assertEqual(b, b.T) | 
 |  | 
 |     @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)) | 
 |     def test_transposes(self, device, dtype): | 
 |         for op in ("T", "H", "mT", "mH", "adjoint"): | 
 |             shapes = ((2, 3), (2, 3, 4)) if op[0] == "m" or op == "adjoint" else ((2, 3),) | 
 |             for shape in shapes: | 
 |                 a = make_tensor(shape, device=device, dtype=dtype) | 
 |                 t1 = getattr(a, op) | 
 |                 if op == "adjoint": | 
 |                     t1 = t1() | 
 |                 t2 = a | 
 |                 t2 = t2.transpose(-2, -1) | 
 |                 if op[-1] == "H" or op == "adjoint": | 
 |                     t2 = t2.conj() | 
 |                 self.assertEqual(t2, t1) | 
 |  | 
 |     @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)) | 
 |     def test_transposes_errors(self, device, dtype): | 
 |         for op in ("H", "mT", "mH", "adjoint"): | 
 |             shapes = ((2,), (2, 3, 4)) if op == "H" else ((2,),) | 
 |             for shape in shapes: | 
 |                 a = make_tensor(shape, device=device, dtype=dtype) | 
 |                 with self.assertRaisesRegex(RuntimeError, "only supported on matrices"): | 
 |                     t1 = getattr(a, op) | 
 |                     if op == "adjoint": | 
 |                         t1 = t1() | 
 |  | 
 |     def test_python_types(self, device): | 
 |         a1 = torch.randn((1, 2), device=device, dtype=torch.float64) | 
 |         a2 = torch.randn((1, 2), device=device, dtype=float) | 
 |         self.assertEqual(a1.dtype, a2.dtype) | 
 |  | 
 |         b1 = torch.arange(10, 20, dtype=torch.int64, device=device) | 
 |         b2 = torch.arange(10, 20, dtype=int, device=device) | 
 |         self.assertEqual(b1.dtype, b2.dtype) | 
 |  | 
 |         c1 = torch.tensor([True, False], dtype=torch.bool, device=device) | 
 |         c2 = torch.tensor([True, False], dtype=bool, device=device) | 
 |         self.assertEqual(c1.dtype, c2.dtype) | 
 |  | 
 |     # TODO: is resize best put in test_view_ops? | 
 |     def test_resize_as_preserves_strides(self, device): | 
 |         x = torch.empty(2, 3).t() | 
 |         old_strides = x.stride() | 
 |         x.resize_as_(x) | 
 |         self.assertEqual(x.stride(), old_strides) | 
 |  | 
 |     def test_memory_format_resize_as(self, device): | 
 |         def test_helper(shape, memory_format, device): | 
 |             xc = torch.randn(shape, device=device).contiguous(memory_format=memory_format) | 
 |             flat = torch.randn(xc.numel(), device=device) | 
 |             flat.resize_as_(xc, memory_format=torch.preserve_format) | 
 |             self.assertTrue(flat.is_contiguous(memory_format=memory_format)) | 
 |  | 
 |         test_helper((10, 3, 32, 32), torch.channels_last, device) | 
 |         test_helper((3, 10, 3, 32, 32), torch.channels_last_3d, device) | 
 |  | 
 |     def test_memory_format_resize_(self, device): | 
 |         def test_helper(shape, numel, memory_format, device): | 
 |             flat = torch.randn(numel, device=device) | 
 |             flat.resize_(shape, memory_format=memory_format) | 
 |             self.assertTrue(flat.is_contiguous(memory_format=memory_format)) | 
 |  | 
 |         test_helper((10, 3, 32, 32), 10 * 3 * 32 * 32, torch.channels_last, device) | 
 |         test_helper((3, 10, 3, 32, 32), 3 * 10 * 3 * 32 * 32, torch.channels_last_3d, device) | 
 |  | 
 |     @onlyNativeDeviceTypes | 
 |     @dtypes(torch.int64, torch.float, torch.complex128) | 
 |     def test_transpose_invalid(self, device, dtype): | 
 |         for fn in (torch.swapdims, torch.swapaxes, torch.transpose): | 
 |             shape = _rand_shape(4, min_size=5, max_size=10) | 
 |             x = _generate_input(shape, dtype, device, False) | 
 |  | 
 |             # Invalid `source` and `destination` dimension | 
 |             with self.assertRaisesRegex(IndexError, "Dimension out of range"): | 
 |                 fn(x, 5, 0) | 
 |  | 
 |             with self.assertRaisesRegex(IndexError, "Dimension out of range"): | 
 |                 fn(x, 0, 5) | 
 |  | 
 |     @dtypes(torch.int64, torch.float, torch.complex128) | 
 |     def test_transpose_vs_numpy(self, device, dtype): | 
 |         for fn in (torch.swapdims, torch.swapaxes, torch.transpose): | 
 |             for nd in range(5): | 
 |                 shape = _rand_shape(nd, min_size=5, max_size=10) | 
 |                 x = _generate_input(shape, dtype, device, with_extremal=False) | 
 |                 for random_negative in [True, False]: | 
 |                     for src_dim, dst_dim in permutations(range(nd), r=2): | 
 |                         random_prob = random.random() | 
 |  | 
 |                         if random_negative and random_prob > 0.66: | 
 |                             src_dim = src_dim - nd | 
 |                         elif random_negative and random_prob > 0.33: | 
 |                             dst_dim = dst_dim - nd | 
 |                         elif random_negative: | 
 |                             src_dim = src_dim - nd | 
 |                             dst_dim = dst_dim - nd | 
 |  | 
 |                         partial_map = { | 
 |                             torch.swapdims: partial(torch.swapdims, dim0=src_dim, dim1=dst_dim), | 
 |                             torch.swapaxes: partial(torch.swapaxes, axis0=src_dim, axis1=dst_dim), | 
 |                             torch.transpose: partial(torch.transpose, dim0=src_dim, dim1=dst_dim), | 
 |                         } | 
 |  | 
 |                         torch_fn = partial_map[fn] | 
 |                         np_fn = partial(np.swapaxes, axis1=src_dim, axis2=dst_dim) | 
 |                         self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None) | 
 |  | 
 |             # Move dim to same position | 
 |             x = torch.randn(2, 3, 5, 7, 11) | 
 |             partial_map = { | 
 |                 torch.swapdims: partial(torch.swapdims, dim0=0, dim1=0), | 
 |                 torch.swapaxes: partial(torch.swapaxes, axis0=0, axis1=0), | 
 |                 torch.transpose: partial(torch.transpose, dim0=0, dim1=0), | 
 |             } | 
 |             torch_fn = partial_map[fn] | 
 |             np_fn = partial(np.swapaxes, axis1=0, axis2=0) | 
 |             self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None) | 
 |  | 
 |     def _test_atleast_dim(self, torch_fn, np_fn, device, dtype): | 
 |         for ndims in range(0, 5): | 
 |             shape = _rand_shape(ndims, min_size=5, max_size=10) | 
 |             for n in range(ndims + 1): | 
 |                 for with_extremal in [False, True]: | 
 |                     for contiguous in [False, True]: | 
 |                         # Generate Input. | 
 |                         x = _generate_input(shape, dtype, device, with_extremal) | 
 |                         if contiguous: | 
 |                             x = x.T | 
 |                         self.compare_with_numpy(torch_fn, np_fn, x, device=None, dtype=None) | 
 |  | 
 |                         # Compare sequence input | 
 |                         torch_sequence_x = (x,) * random.randint(3, 10) | 
 |                         np_sequence_x = tuple(np.array(x.detach().cpu().numpy()) for x in torch_sequence_x) | 
 |                         torch_res = torch_fn(*torch_sequence_x) | 
 |                         np_res = np_fn(*np_sequence_x) | 
 |  | 
 |                         torch_res = tuple(x.cpu() for x in torch_res) | 
 |                         np_res = tuple(torch.from_numpy(x) for x in np_res) | 
 |                         self.assertEqual(np_res, torch_res) | 
 |  | 
 |     # TODO: are these view ops? | 
 |     @dtypes(*all_types_and_complex_and(torch.half)) | 
 |     def test_atleast(self, device, dtype): | 
 |         self._test_atleast_dim(torch.atleast_1d, np.atleast_1d, device, dtype) | 
 |         self._test_atleast_dim(torch.atleast_2d, np.atleast_2d, device, dtype) | 
 |         self._test_atleast_dim(torch.atleast_3d, np.atleast_3d, device, dtype) | 
 |  | 
 |     # TODO: OpInfo this | 
 |     def _test_atleast(self, device, torch_fn): | 
 |         # 0-dim | 
 |         s = torch.tensor(0.5, dtype=torch.double, requires_grad=True) | 
 |  | 
 |         gradcheck(lambda x: torch_fn(x), s) | 
 |         gradgradcheck(lambda x: torch_fn(x), s) | 
 |  | 
 |         # 1-dim | 
 |         a = torch.rand(4, dtype=torch.double, requires_grad=True) | 
 |  | 
 |         gradcheck(lambda x: torch_fn(x), a) | 
 |         gradgradcheck(lambda x: torch_fn(x), a) | 
 |  | 
 |         # 2,3,4-dim | 
 |         b = torch.rand(4, 3, dtype=torch.double, requires_grad=True) | 
 |         c = torch.rand(4, 3, 2, dtype=torch.double, requires_grad=True) | 
 |         d = torch.rand(4, 3, 2, 1, dtype=torch.double, requires_grad=True) | 
 |  | 
 |         input_tuple = (s, a, b, c, d) | 
 |         gradcheck(lambda s, w, x, y, z: torch_fn(s, w, x, y, z), input_tuple) | 
 |         gradgradcheck(lambda s, w, x, y, z: torch_fn(s, w, x, y, z), input_tuple) | 
 |  | 
 |     def test_atleast_gradient(self, device): | 
 |         self._test_atleast(device, torch.atleast_1d) | 
 |         self._test_atleast(device, torch.atleast_2d) | 
 |         self._test_atleast(device, torch.atleast_3d) | 
 |  | 
 |  | 
 |     @onlyCPU | 
 |     @dtypes(torch.float) | 
 |     def test_broadcast_tensors(self, device, dtype): | 
 |         x0 = torch.randn(2, 1, 3, dtype=dtype, device=device) | 
 |         x1 = torch.randn(3, dtype=dtype, device=device) | 
 |         x2 = torch.randn(3, 1, dtype=dtype, device=device) | 
 |         expected_size = (2, 3, 3) | 
 |  | 
 |         y0, y1, y2 = torch.broadcast_tensors(x0, x1, x2) | 
 |         self.assertTrue(y0.size() == expected_size) | 
 |         self.assertTrue(y1.size() == expected_size) | 
 |         self.assertTrue(y2.size() == expected_size) | 
 |  | 
 |  | 
 |     @onlyCPU | 
 |     def test_broadcast_shapes(self, device): | 
 |         examples = [(), (1,), (2,), (1, 1), (3, 1), (3, 2), (4, 1, 1), (4, 3, 2)] | 
 |         for s0 in examples: | 
 |             x0 = torch.randn(s0) | 
 |             expected = torch.broadcast_tensors(x0)[0].shape | 
 |             actual = torch.broadcast_shapes(s0) | 
 |             self.assertEqual(expected, actual) | 
 |  | 
 |             for s1 in examples: | 
 |                 x1 = torch.randn(s1) | 
 |                 expected = torch.broadcast_tensors(x0, x1)[0].shape | 
 |                 actual = torch.broadcast_shapes(s0, s1) | 
 |                 self.assertEqual(expected, actual) | 
 |  | 
 |         inputs_list = [[1, 4], [4, 1], [1, 1, 3]] | 
 |         for integral_inputs in inputs_list: | 
 |             res1 = torch.broadcast_shapes(*integral_inputs) | 
 |             res2 = torch.broadcast_tensors(*map(torch.empty, integral_inputs))[0].shape | 
 |             self.assertEqual(res1, res2) | 
 |  | 
 |         inputs_with_neg_vals = [[1, 1, -12], [-1, 1], [-11, ]] | 
 |         for integral_inputs_with_neg_vals in inputs_with_neg_vals: | 
 |             with self.assertRaisesRegex(RuntimeError, "Trying to create tensor with negative dimension"): | 
 |                 torch.broadcast_shapes(*integral_inputs_with_neg_vals) | 
 |  | 
 |         integral_inputs_error_case = [(3, 5), (2, 4, 1)] | 
 |         for error_input in integral_inputs_error_case: | 
 |             with self.assertRaisesRegex(RuntimeError, "Shape mismatch: objects cannot be broadcast to a single shape"): | 
 |                 torch.broadcast_shapes(*error_input) | 
 |  | 
 |         negative_inputs = [(-1,), (1, -12), (4, -11), (-4, 1), (1, 1, -2)] | 
 |         for s0 in negative_inputs: | 
 |             with self.assertRaisesRegex(RuntimeError, "Trying to create tensor with negative dimension"): | 
 |                 torch.broadcast_shapes(s0) | 
 |  | 
 |             for s1 in negative_inputs: | 
 |                 with self.assertRaisesRegex(RuntimeError, "Trying to create tensor with negative dimension"): | 
 |                     torch.broadcast_shapes(s0, s1) | 
 |  | 
 |         float_inputs_error_case = [(1.1, 2.0), (1.1, 1.0)] | 
 |         for error_case in float_inputs_error_case: | 
 |             for float_input in error_case: | 
 |                 with self.assertRaisesRegex(RuntimeError, "Input shapes " | 
 |                                             "should be of type ints, a tuple of ints, or a list of ints"): | 
 |                     torch.broadcast_shapes(float_input) | 
 |  | 
 |         diff_input_types = [(1, (5,)), (3, (1,)), (1, (3, 4))] | 
 |         for s0 in diff_input_types: | 
 |             res1 = torch.broadcast_shapes(*s0) | 
 |             res2 = torch.broadcast_tensors(*map(torch.empty, s0))[0].shape | 
 |             self.assertEqual(res1, res2) | 
 |  | 
 |     # Skip BFloat16 since numpy does not support it | 
 |     @dtypes(*all_types_and_complex_and(torch.half, torch.bool)) | 
 |     def test_broadcast_to(self, device, dtype): | 
 |         def can_broadcast(s0, s1): | 
 |             # s0.dim() <= s1.dim(), reverse s0 and s1 to compare trailing dimension | 
 |             s0 = tuple(reversed(s0)) | 
 |             s1 = tuple(reversed(s1)) | 
 |             for i in range(len(s0)): | 
 |                 if s0[i] != 1 and s0[i] != s1[i]: | 
 |                     return False | 
 |             return True | 
 |  | 
 |         sizes = ( | 
 |             (), (1,), (2,), (1, 1), (3, 1), (3, 2), (4, 1, 1), (4, 3, 2) | 
 |         ) | 
 |         for s0, s1 in combinations(sizes, r=2): | 
 |             t = make_tensor(s0, dtype=dtype, device=device, low=-9, high=9) | 
 |             t_np = t.cpu().numpy() | 
 |  | 
 |             if can_broadcast(s0, s1): | 
 |                 res = torch.broadcast_to(t, s1) | 
 |                 np_res = np.broadcast_to(t_np, s1) | 
 |                 self.assertEqual(res, np_res) | 
 |             else: | 
 |                 with self.assertRaisesRegex(RuntimeError, | 
 |                                             r"The expanded size of the tensor \(\d\) " | 
 |                                             r"must match the existing size \(\d\)"): | 
 |                     torch.broadcast_to(t, s1) | 
 |  | 
 |     def test_view(self, device): | 
 |         tensor = torch.rand(15, device=device) | 
 |         template = torch.rand(3, 5, device=device) | 
 |         empty = torch.empty(0, device=device) | 
 |         target = template.size() | 
 |         self.assertEqual(tensor.view_as(template).size(), target) | 
 |         self.assertEqual(tensor.view(3, 5).size(), target) | 
 |         self.assertEqual(tensor.view(torch.Size([3, 5])).size(), target) | 
 |         self.assertEqual(tensor.view(-1, 5).size(), target) | 
 |         self.assertEqual(tensor.view(3, -1).size(), target) | 
 |         tensor_view = tensor.view(5, 3) | 
 |         tensor_view.fill_(random.uniform(0, 1)) | 
 |         self.assertEqual(empty.view_as(empty), empty) | 
 |         self.assertEqual(empty.view(0), empty) | 
 |         self.assertEqual(empty.view(0, 3, 0, 1).size(), torch.Size([0, 3, 0, 1])) | 
 |         self.assertEqual(empty.view(0, 3, 0, 1).view(0), empty) | 
 |  | 
 |         # test size inference with empty tensors | 
 |         self.assertEqual(empty.view(-1).size(), torch.Size([0])) | 
 |         self.assertEqual(empty.view(10, 3, -1).size(), torch.Size([10, 3, 0])) | 
 |  | 
 |         with self.assertRaisesRegex(RuntimeError, r"because the unspecified dimension size -1 can be any value"): | 
 |             empty.view(-1, 0) | 
 |  | 
 |         with self.assertRaisesRegex(RuntimeError, r"because the unspecified dimension size -1 can be any value"): | 
 |             empty.view(3, 0, -1, 0) | 
 |  | 
 |         self.assertRaises(RuntimeError, lambda: tensor.view(15, 0)) | 
 |         self.assertRaises(RuntimeError, lambda: tensor.view(7, -1)) | 
 |         self.assertRaises(RuntimeError, lambda: tensor.view(15, -1, -1)) | 
 |  | 
 |         # test view when tensor is not contiguous in every dimension, but only | 
 |         # contiguous dimensions are touched. | 
 |         tensor = torch.rand(4, 2, 5, 1, 6, 2, 9, 3, device=device).transpose(-1, 2).transpose(-2, 3) | 
 |         # size:                      [   4,    2,    3,    9,    6,    2,    1,    5] | 
 |         # stride:                    [3840, 1620,    1,    3,   54,   27,  324,  324] | 
 |         # contiguous dim chunks:     [__________, ____, ____, __________, ____, ____] | 
 |         # merging 1 to chunk after:  [__________, ____, ____, __________, __________] | 
 |         contig_tensor = tensor.clone() | 
 |         # [4, 2] => [8, 1] | 
 |         # [3] => [3] | 
 |         # [9] => [3, 3] | 
 |         # [6, 2] => [4, 1, 3] | 
 |         # [1, 5] => [5] | 
 |         view_size = [8, 1, 3, 3, 3, 4, 1, 3, 5] | 
 |         self.assertEqual(tensor.view(*view_size), contig_tensor.view(*view_size)) | 
 |         # [4, 2] => [2, 4] | 
 |         # [3] => [3] | 
 |         # [9] => [1, 9] | 
 |         # [6, 2] => [2, 2, 3] | 
 |         # [1, 5] => [5, 1] | 
 |         view_size = [2, 4, 3, 1, 9, 2, 2, 3, 5, 1] | 
 |         self.assertEqual(tensor.view(*view_size), contig_tensor.view(*view_size)) | 
 |         # adding size 1 dims | 
 |         view_size = [1, 1, 2, 1, 4, 3, 1, 1, 9, 1, 2, 1, 2, 3, 1, 5, 1, 1] | 
 |         self.assertEqual(tensor.view(*view_size), contig_tensor.view(*view_size)) | 
 |  | 
 |         # invalid views | 
 |         self.assertRaises(RuntimeError, lambda: tensor.view(-1)) | 
 |         # crossing [4, 2], [3] | 
 |         self.assertRaises(RuntimeError, lambda: tensor.view(24, 9, 6, 2, 1, 5)) | 
 |         # crossing [6, 2], [1, 5] | 
 |         self.assertRaises(RuntimeError, lambda: tensor.view(8, 3, 9, 6, 10)) | 
 |         # crossing [9], [6, 2] | 
 |         self.assertRaises(RuntimeError, lambda: tensor.view(8, 3, 54, 2, 1, 5)) | 
 |  | 
 |         # view with stride 0 dims | 
 |         tensor = torch.empty(1, 1, device=device).expand(3, 4)  # all dims are contiguous | 
 |         contig_tensor = tensor.clone() | 
 |         self.assertEqual(tensor.view(-1), contig_tensor.view(-1)) | 
 |         self.assertEqual(tensor.view(1, -1, 1), contig_tensor.view(1, -1, 1)) | 
 |         self.assertEqual(tensor.view(-1, 1), contig_tensor.view(-1, 1)) | 
 |         self.assertEqual(tensor.view(6, 2, 1), contig_tensor.view(6, 2, 1)) | 
 |         self.assertEqual(tensor.view(1, 6, 2, 1), contig_tensor.view(1, 6, 2, 1)) | 
 |  | 
 |     @dtypes(*all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool)) | 
 |     def test_reshape_view_semantics(self, device, dtype): | 
 |         tensor = make_tensor((15, 4), dtype=dtype, device=device) | 
 |         target = (20, 3) | 
 |  | 
 |         # Cases where the tensor can be returned as a view. | 
 |         view_tensor = tensor.reshape(target) | 
 |         self.assertEqual((view_tensor.size()), target) | 
 |         self.assertEqual(tensor.storage().data_ptr(), view_tensor.storage().data_ptr()) | 
 |  | 
 |         # Cases where the tensor must be copied (transpose makes it non-contiguous forcing | 
 |         # the copy). | 
 |         copy_tensor = tensor.transpose(0, 1).reshape(target) | 
 |         self.assertEqual(copy_tensor.size(), target) | 
 |         self.assertNotEqual(tensor.storage().data_ptr(), copy_tensor.storage().data_ptr()) | 
 |  | 
 |     def test_contiguous(self, device): | 
 |         x = torch.randn(1, 16, 5, 5, device=device) | 
 |         self.assertTrue(x.is_contiguous()) | 
 |         stride = list(x.stride()) | 
 |         stride[0] = 20 | 
 |         # change the stride in dimension 0. the tensor is still contiguous because size[0] is 1 | 
 |         x.set_(x.storage(), 0, x.size(), stride) | 
 |         self.assertTrue(x.is_contiguous()) | 
 |  | 
 |     @onlyNativeDeviceTypes | 
 |     # Skip BFloat16 since numpy does not support it | 
 |     @dtypes(*all_types_and_complex_and(torch.half, torch.bool)) | 
 |     def test_tensor_split_sections(self, device, dtype): | 
 |         input_sizes = [ | 
 |             (0,), | 
 |             (10,), | 
 |             (10, 0), | 
 |             (0, 10), | 
 |             (4, 10), | 
 |             (12, 3), | 
 |         ] | 
 |         for input_size in input_sizes: | 
 |             a_base = make_tensor(input_size, dtype=dtype, device=device, low=-9, high=9) | 
 |             # Run tests on transposed input if it has at least 2 dims | 
 |             for a in [a_base, a_base.t()] if a_base.dim() > 2 else [a_base]: | 
 |                 a_n = a.cpu().numpy() | 
 |                 for dim in range(-a.dim(), a.dim()): | 
 |                     for sections in range(1, 2 * a.size(dim)): | 
 |                         msg = f'input_size {input_size}, sections {sections}, dim {dim}' | 
 |                         result1 = torch.tensor_split(a, sections, dim) | 
 |                         result2 = torch.tensor_split(a, torch.tensor(sections, dtype=torch.int64), dim) | 
 |                         for r1, r2 in zip(result1, result2): | 
 |                             self.assertEqual(r1.device, torch.device(device), msg=msg) | 
 |                             self.assertEqual(r1.dtype, dtype, msg=msg) | 
 |                             self.assertEqual(r2.device, torch.device(device), msg=msg) | 
 |                             self.assertEqual(r2.dtype, dtype, msg=msg) | 
 |                         result_n = np.array_split(a_n, sections, dim) | 
 |                         self.assertEqual(result_n, result1, msg=msg) | 
 |                         self.assertEqual(result_n, result2, msg=msg) | 
 |  | 
 |     @onlyNativeDeviceTypes | 
 |     # Skip BFloat16 since numpy does not support it | 
 |     @dtypes(*all_types_and_complex_and(torch.half, torch.bool)) | 
 |     def test_tensor_split_indices(self, device, dtype): | 
 |         input_sizes = [ | 
 |             (0,), | 
 |             (10,), | 
 |             (10, 0), | 
 |             (0, 10), | 
 |             (4, 10), | 
 |             (12, 3), | 
 |         ] | 
 |         indices_args = [ | 
 |             (), | 
 |             (0,), | 
 |             (3,), | 
 |             (10,), | 
 |             (-1,), | 
 |             (-10,), | 
 |             (2, -1), | 
 |             (3, 4, 10), | 
 |             (0, -1, 0, 10), | 
 |             (1, 5, 2, 8), | 
 |         ] | 
 |         for input_size in input_sizes: | 
 |             a_base = make_tensor(input_size, dtype=dtype, device=device, low=-9, high=9) | 
 |             # Run tests on transposed input if it has at least 2 dims | 
 |             for a in [a_base, a_base.t()] if a_base.dim() > 2 else [a_base]: | 
 |                 a_n = a.cpu().numpy() | 
 |                 for dim in range(-a.dim(), a.dim()): | 
 |                     for indices in indices_args: | 
 |                         result_1 = torch.tensor_split(a, indices, dim) | 
 |                         result_2 = torch.tensor_split(a, torch.tensor(indices, dtype=torch.int64), dim) | 
 |  | 
 |                         msg = f'input_size {input_size}, indices {indices}, dim {dim}' | 
 |                         for r1, r2 in zip(result_1, result_2): | 
 |                             self.assertEqual(r1.device, torch.device(device), msg=msg) | 
 |                             self.assertEqual(r1.dtype, dtype, msg=msg) | 
 |                             self.assertEqual(r2.device, torch.device(device), msg=msg) | 
 |                             self.assertEqual(r2.dtype, dtype, msg=msg) | 
 |  | 
 |                         result_n = np.array_split(a_n, indices, dim) | 
 |                         self.assertEqual(result_n, result_1, msg=msg) | 
 |                         self.assertEqual(result_n, result_2, msg=msg) | 
 |  | 
 |     @onlyNativeDeviceTypes | 
 |     def test_tensor_split_errors(self, device): | 
 |         S = 10 | 
 |         test_cases = [ | 
 |             # input size, sections or indices, dim, error type, error message, numpy error type | 
 |             [(S,), 10, 1, IndexError, r'Dimension out of range', IndexError], | 
 |             [(), 10, 0, RuntimeError, r'tensor_split expected at least a 1-dimensional tensor, ' | 
 |                 + 'but got a tensor with 0 dims', IndexError], | 
 |             [(S,), (10,), 1, IndexError, r'Dimension out of range', IndexError], | 
 |             [(), (10,), 0, RuntimeError, r'tensor_split expected at least a 1-dimensional tensor, ' | 
 |                 + 'but got a tensor with 0 dims', IndexError], | 
 |             [(S,), 0, 0, RuntimeError, r'number of sections must be larger than 0, got 0', ValueError], | 
 |             [(S,), -1, 0, RuntimeError, r'number of sections must be larger than 0, got -1', ValueError], | 
 |         ] | 
 |         for input_size, sections_or_indices, dim, err, err_msg, numpy_err in test_cases: | 
 |             a = torch.randn(input_size, device=device) | 
 |             msg = f'input_size {input_size}, sections_or_indices {sections_or_indices}, dim {dim}' | 
 |             with self.assertRaisesRegex(err, err_msg, msg=msg): | 
 |                 torch.tensor_split(a, sections_or_indices, dim) | 
 |             with self.assertRaisesRegex(err, err_msg, msg=msg): | 
 |                 torch.tensor_split(a, torch.tensor(sections_or_indices), dim) | 
 |             with self.assertRaises(numpy_err, msg=msg): | 
 |                 np.array_split(a.cpu().numpy(), sections_or_indices, dim) | 
 |  | 
 |         # addtional tests for tensor_split with tensor_indices_or_sections | 
 |         with self.assertRaisesRegex(RuntimeError, | 
 |                                     r'tensor_split expected tensor_indices_or_sections to have dtype of long, but got Float'): | 
 |             torch.tensor_split(a, torch.tensor(1.1), dim) | 
 |  | 
 |         with self.assertRaisesRegex(RuntimeError, | 
 |                                     r'tensor_split expected tensor_indices_or_sections to be a' | 
 |                                     + ' zero-dimensional or one-dimensional tensor, but got a tensor with 2 dims'): | 
 |             torch.tensor_split(torch.rand(S, device=device), torch.tensor(((1,),)), 0) | 
 |  | 
 |     def test_resize_all_dtypes_and_devices(self, device): | 
 |         shape = (2, 2) | 
 |         for dt in all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool): | 
 |             x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device) | 
 |             x.resize_(shape) | 
 |             self.assertEqual(shape, x.shape) | 
 |  | 
 |     def test_resize_as_all_dtypes_and_devices(self, device): | 
 |         for dt in all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool): | 
 |             x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device) | 
 |             y = torch.tensor([[1, 2, 3], [4, 5, 6]], dtype=dt, device=device) | 
 |             x.resize_as_(y) | 
 |             self.assertEqual(y.shape, x.shape) | 
 |  | 
 |     @onlyNativeDeviceTypes | 
 |     def test_resize_overflow(self, device): | 
 |         x = torch.empty((), dtype=torch.float64) | 
 |         with self.assertRaisesRegex(RuntimeError, 'Storage size calculation overflowed'): | 
 |             x.resize_([2, 4, 2**29, 2**29]) | 
 |         with self.assertRaisesRegex(RuntimeError, 'overflow'): | 
 |             x.resize_([8, 8, 2**29, 2**29]) | 
 |         with self.assertRaisesRegex(RuntimeError, 'Stride calculation overflowed'): | 
 |             x.resize_([0, 4, 2305843009213693952]) | 
 |  | 
 |     def test_view_all_dtypes_and_devices(self, device): | 
 |         for dt in all_types_and_complex_and(torch.half, torch.bfloat16, torch.bool): | 
 |             x = torch.tensor([[1, 2], [3, 4], [5, 6]], dtype=dt, device=device) | 
 |             self.assertEqual(x.view(6).shape, [6]) | 
 |  | 
 |     @onlyCPU | 
 |     def test_conj_neg_view_numpy_error(self, device): | 
 |         self.assertRaisesRegex(RuntimeError, "has conjugate bit set", lambda: torch.tensor([1 + 2j]).conj().numpy()) | 
 |         self.assertRaisesRegex(RuntimeError, "has negative bit set", lambda: torch.tensor([1 + 2j]).conj().imag.numpy()) | 
 |         self.assertRaisesRegex(RuntimeError, "not supported for conjugate view tensors", | 
 |                                lambda: torch.tensor([1 + 2j]).conj().view(torch.float64)) | 
 |         self.assertRaisesRegex(RuntimeError, "not supported for tensors with negative bit set", | 
 |                                lambda: torch.tensor([1 + 2j]).conj().imag.view(torch.int32)) | 
 |  | 
 |     @onlyCPU | 
 |     def test_crow_col_indices(self, device): | 
 |         crow_indices = (0, 1, 2) | 
 |         col_indices = (1, 0) | 
 |         values = (1, 2) | 
 |         t = torch.sparse_csr_tensor(crow_indices, col_indices, values, size=(2, 2)) | 
 |         # This is the test. If crow_indices is not a view op it'll | 
 |         # trigger an internal assert due to use count greater than 1 | 
 |         # in debug build. | 
 |         t.crow_indices() | 
 |         t.col_indices() | 
 |  | 
 | instantiate_device_type_tests(TestViewOps, globals(), include_lazy=True) | 
 | instantiate_device_type_tests(TestOldViewOps, globals()) | 
 |  | 
 | if __name__ == '__main__': | 
 |     run_tests() |