|  | import unittest | 
|  | from torch.testing._internal.common_utils import TestCase, run_tests, TEST_NUMPY | 
|  | from torch.testing._internal.common_cuda import TEST_CUDA | 
|  | from collections import namedtuple, OrderedDict | 
|  | import itertools | 
|  | import functools | 
|  | import torch | 
|  | from torch import Tensor | 
|  | import torch.nn.functional as F | 
|  | from multiprocessing.reduction import ForkingPickler | 
|  | import pickle | 
|  | import io | 
|  | import sys | 
|  | import warnings | 
|  |  | 
|  |  | 
def pass_name_to_python_arg_parser(name):
    """Route ``name`` through the Python argument parser as a dimension name.

    A 1-D tensor is constructed with ``names=(name,)`` purely for its side
    effect of validating ``name``; the tensor is discarded and any validation
    error propagates to the caller.
    """
    torch.empty(2, names=(name,))
|  |  | 
|  |  | 
def flatten(lst):
    """Concatenate the sub-iterables of ``lst`` into one flat list (one level deep)."""
    return list(itertools.chain.from_iterable(lst))
|  |  | 
|  |  | 
# Pairs a human-readable op name with the callable under test.
# The typename matches the variable name so instances repr as Function(...)
# and can be pickled (pickle resolves the class by its typename in this module).
Function = namedtuple('Function', ['name', 'lambd'])
|  |  | 
|  |  | 
def parse_compressed_namedshape(string):
    """Parse a compact named-shape description into ``(names, sizes)``.

    This is a small metalanguage for describing a tensor's shape compactly:

    * ``'N:3,C:2'``       -> names ``('N', 'C')``, sizes ``(3, 2)``
    * ``'None:3,None:2'`` -> names ``(None, None)``, sizes ``(3, 2)``
      (the literal ``None`` becomes Python ``None``, i.e. an unnamed dim)
    * ``'3,2'``           -> names ``None``, sizes ``[3, 2]``
    * ``''``              -> names ``None``, sizes ``[]``

    Whitespace around names and sizes is ignored.

    Returns:
        A 2-tuple ``(names, sizes)`` suitable for ``factory(sizes, names=names)``.

    Raises:
        ValueError: if a size is not an integer, or a named entry does not
            contain exactly one ``':'``.
    """
    def parse_name(maybe_name):
        maybe_name = maybe_name.strip()
        if maybe_name == 'None':
            return None
        return maybe_name

    string = string.strip()

    # '' -> no dims, names=None
    if not string:
        return None, []

    # '3, 2' -> sizes only, names=None
    if ':' not in string:
        return None, [int(size) for size in string.split(',')]

    tuples = [dim.split(':') for dim in string.split(',')]
    # Unpack the transposed pairs into a real tuple. Returning the bare ``zip``
    # iterator made the result single-use and non-indexable, unlike the other
    # branches.
    names, sizes = zip(*[(parse_name(name), int(size)) for name, size in tuples])
    return names, sizes
|  |  | 
|  |  | 
def create(namedshape, factory=torch.randn):
    """Build a tensor from a compressed named-shape string using ``factory``.

    ``namedshape`` uses the syntax accepted by ``parse_compressed_namedshape``
    (e.g. ``'N:3,C:2'``). The parsed sizes and names are forwarded as
    ``factory(sizes, names=names)``.
    """
    parsed_names, parsed_sizes = parse_compressed_namedshape(namedshape)
    return factory(parsed_sizes, names=parsed_names)
|  |  | 
|  |  | 
|  | def out_fn(operator): | 
|  | @functools.wraps(operator) | 
|  | def fn(*inputs): | 
|  | return operator(*inputs[1:], out=inputs[0]) | 
|  | return fn | 
|  |  | 
|  |  | 
|  | class TestNamedTensor(TestCase): | 
|  | def test_aaa_must_run_first_check_experimental_warning(self): | 
|  | # TODO(rzou): It would be nice for this to be a "real" python warning. | 
|  | # Right now this error message only prints once and doesn't respect | 
|  | # warnings.simplefilter behavior (where python users can control whether | 
|  | # or not to display warnings once, all the time, or never). | 
|  | with warnings.catch_warnings(record=True) as warns: | 
|  | x = torch.randn(3, 3, names=('N', 'C')) | 
|  | self.assertEqual(len(warns), 1) | 
|  | self.assertTrue(str(warns[0].message).startswith( | 
|  | 'Named tensors and all their associated APIs are an experimental feature')) | 
|  |  | 
|  | def test_trivial(self): | 
|  | pass | 
|  |  | 
|  | def _test_name_inference(self, op, args=(), expected_names=(), device='cpu', | 
|  | maybe_raises_regex=None): | 
|  | casted_args = [arg.to(device) if isinstance(arg, torch.Tensor) else arg | 
|  | for arg in args] | 
|  | if maybe_raises_regex is not None: | 
|  | with self.assertRaisesRegex(RuntimeError, maybe_raises_regex): | 
|  | result = op(*args) | 
|  | return | 
|  | result = op(*args) | 
|  | self.assertEqual(result.names, expected_names, | 
|  | msg='Name inference for {} on device {} failed'.format( | 
|  | op.__name__, device)) | 
|  |  | 
|  | # TODO(rzou): Some form of this check should be added to self.assertEqual. | 
|  | # Right now I don't know what it should look like. | 
|  | def assertTensorDataAndNamesEqual(self, x, y): | 
|  | self.assertEqual(x.names, y.names) | 
|  | unnamed_x = x.rename(None) | 
|  | unnamed_y = y.rename(None) | 
|  | self.assertEqual(unnamed_x, unnamed_y) | 
|  |  | 
|  | def _test_factory(self, factory, device): | 
|  | x = factory([], device=device) | 
|  | self.assertEqual(x.names, ()) | 
|  |  | 
|  | x = factory(1, 2, 3, device=device) | 
|  | self.assertEqual(x.names, (None, None, None)) | 
|  |  | 
|  | x = factory(1, 2, 3, names=None, device=device) | 
|  | self.assertEqual(x.names, (None, None, None)) | 
|  |  | 
|  | x = factory(1, 2, 3, names=('N', 'T', 'D'), device=device) | 
|  | self.assertEqual(x.names, ('N', 'T', 'D')) | 
|  |  | 
|  | x = factory(1, 2, 3, names=('N', None, 'D'), device=device) | 
|  | self.assertEqual(x.names, ('N', None, 'D')) | 
|  |  | 
|  | x = factory(1, 2, 3, names=('_1', 'batch9', 'BATCH_5'), device=device) | 
|  | self.assertEqual(x.names, ('_1', 'batch9', 'BATCH_5')) | 
|  |  | 
|  | with self.assertRaisesRegex(RuntimeError, | 
|  | 'a valid identifier contains only'): | 
|  | x = factory(2, names=('1',), device=device) | 
|  |  | 
|  | with self.assertRaisesRegex(RuntimeError, | 
|  | 'a valid identifier contains only'): | 
|  | x = factory(2, names=('?',), device=device) | 
|  |  | 
|  | with self.assertRaisesRegex(RuntimeError, 'Number of names'): | 
|  | x = factory(2, 1, names=('N',), device=device) | 
|  |  | 
|  | with self.assertRaisesRegex(TypeError, 'invalid combination of arguments'): | 
|  | x = factory(2, 1, names='N', device=device) | 
|  |  | 
|  | with self.assertRaisesRegex(RuntimeError, 'construct a tensor with duplicate names'): | 
|  | x = factory(2, 1, 1, names=('N', 'C', 'N'), device=device) | 
|  |  | 
|  | names64 = ['A' * i for i in range(1, 65)] | 
|  | x = factory([1] * 64, names=names64, device=device) | 
|  | self.assertEqual(x.names, names64) | 
|  |  | 
|  | with self.assertRaisesRegex( | 
|  | RuntimeError, | 
|  | 'only support up to 64 dims'): | 
|  | names65 = ['A' * i for i in range(1, 66)] | 
|  | x = factory([1] * 65, names=names64, device=device) | 
|  |  | 
|  | def test_none_names_refcount(self, N=10): | 
|  | def scope(): | 
|  | unnamed = torch.empty(2, 3) | 
|  | unnamed.names  # materialize [None, None] | 
|  |  | 
|  | prev_none_refcnt = sys.getrefcount(None) | 
|  | # Ran it N times to reduce flakiness | 
|  | [scope() for i in range(N)] | 
|  | after_none_refcnt = sys.getrefcount(None) | 
|  | self.assertTrue(after_none_refcnt - prev_none_refcnt < N / 2, | 
|  | msg='Using tensor.names should not change ' | 
|  | 'the refcount of Py_None') | 
|  |  | 
|  | def test_has_names(self): | 
|  | unnamed = torch.empty(2, 3) | 
|  | none_named = torch.empty(2, 3, names=(None, None)) | 
|  | partially_named = torch.empty(2, 3, names=('N', None)) | 
|  | fully_named = torch.empty(2, 3, names=('N', 'C')) | 
|  |  | 
|  | self.assertFalse(unnamed.has_names()) | 
|  | self.assertFalse(none_named.has_names()) | 
|  | self.assertTrue(partially_named.has_names()) | 
|  | self.assertTrue(fully_named.has_names()) | 
|  |  | 
|  | def test_py3_ellipsis(self): | 
|  | tensor = torch.randn(2, 3, 5, 7) | 
|  | output = tensor.refine_names('N', ..., 'C') | 
|  | self.assertEqual(output.names, ['N', None, None, 'C']) | 
|  |  | 
|  | def test_refine_names(self): | 
|  | # Unnamed tensor -> Unnamed tensor | 
|  | self._test_name_inference(Tensor.refine_names, | 
|  | [create('None:1,None:2,None:3'), 'N', 'C', 'H'], | 
|  | ['N', 'C', 'H']) | 
|  |  | 
|  | # Named tensor -> Named tensor | 
|  | self._test_name_inference(Tensor.refine_names, | 
|  | [create('N:1,C:2,H:3'), 'N', 'C', 'H'], | 
|  | ['N', 'C', 'H']) | 
|  |  | 
|  | # Partially named tensor -> named tensor | 
|  | self._test_name_inference(Tensor.refine_names, | 
|  | [create('None:1,C:2,None:3'), None, 'C', 'H'], | 
|  | [None, 'C', 'H']) | 
|  |  | 
|  | # Too few names | 
|  | self._test_name_inference(Tensor.refine_names, | 
|  | [create('None:2,None:3'), 'N', 'C', 'H'], | 
|  | maybe_raises_regex="different number of dims") | 
|  |  | 
|  | # Cannot change Tensor[D] to Tensor[N] | 
|  | self._test_name_inference(Tensor.refine_names, | 
|  | [create('D:3'), 'N'], | 
|  | maybe_raises_regex="is different from") | 
|  |  | 
|  | # Cannot change Tensor[D] to Tensor[None] | 
|  | self._test_name_inference(Tensor.refine_names, | 
|  | [create('D:3'), None], | 
|  | maybe_raises_regex="'D' is more specific than None") | 
|  |  | 
|  | # globbing behavior exists | 
|  | self._test_name_inference(Tensor.refine_names, | 
|  | [create('None:1,None:1,None:2,None:3'), '...', 'C', 'H'], | 
|  | [None, None, 'C', 'H']) | 
|  |  | 
|  | def test_detach(self): | 
|  | names = ['N'] | 
|  | self._test_name_inference( | 
|  | Tensor.detach_, | 
|  | [torch.randn(3, requires_grad=True, names=names)], | 
|  | names) | 
|  | self._test_name_inference( | 
|  | Tensor.detach, | 
|  | [torch.randn(3, requires_grad=True, names=names)], | 
|  | names) | 
|  |  | 
|  | def test_index_fill(self): | 
|  | for device in torch.testing.get_all_device_types(): | 
|  | expected_names = ('N', 'C') | 
|  | x = torch.randn(3, 5, device=device, names=expected_names) | 
|  |  | 
|  | output = x.index_fill_('C', torch.tensor([0, 1], device=device), 5) | 
|  | self.assertEqual(output.names, expected_names) | 
|  |  | 
|  | output = x.index_fill_('C', torch.tensor([0, 1], device=device), torch.tensor(4.)) | 
|  | self.assertEqual(output.names, expected_names) | 
|  |  | 
|  | output = x.index_fill('C', torch.tensor([0, 1], device=device), 5) | 
|  | self.assertEqual(output.names, expected_names) | 
|  |  | 
|  | output = x.index_fill('C', torch.tensor([0, 1], device=device), torch.tensor(4.)) | 
|  | self.assertEqual(output.names, expected_names) | 
|  |  | 
|  | def test_equal(self): | 
|  | for device in torch.testing.get_all_device_types(): | 
|  | tensor = torch.randn(2, 3, device=device) | 
|  | other = tensor.clone() | 
|  |  | 
|  | self.assertTrue(torch.equal(tensor.rename('N', 'C'), other.rename('N', 'C'))) | 
|  | self.assertFalse(torch.equal(tensor.rename('M', 'C'), other.rename('N', 'C'))) | 
|  | self.assertFalse(torch.equal(tensor.rename(None, 'C'), other.rename('N', 'C'))) | 
|  |  | 
|  | def test_squeeze(self): | 
|  | x = create('N:3,C:1,H:1,W:1') | 
|  | output = x.squeeze('C') | 
|  | self.assertEqual(output.names, ['N', 'H', 'W']) | 
|  |  | 
|  | output = x.squeeze() | 
|  | self.assertEqual(output.names, ['N']) | 
|  |  | 
|  | def test_repr(self): | 
|  | named_tensor = torch.zeros(2, 3).rename_('N', 'C') | 
|  | expected = "tensor([[0., 0., 0.],\n        [0., 0., 0.]], names=('N', 'C'))" | 
|  | self.assertEqual(repr(named_tensor), expected) | 
|  |  | 
|  | unnamed_tensor = torch.zeros(2, 3) | 
|  | expected = "tensor([[0., 0., 0.],\n        [0., 0., 0.]])" | 
|  | self.assertEqual(repr(unnamed_tensor), expected) | 
|  |  | 
|  | none_named_tensor = torch.zeros(2, 3).rename_(None, None) | 
|  | self.assertEqual(repr(none_named_tensor), expected) | 
|  |  | 
|  | def test_diagonal(self): | 
|  | named_tensor = torch.zeros(2, 3, 5, 7, names=list('ABCD')) | 
|  | self.assertEqual(named_tensor.diagonal().names, ['C', 'D', None]) | 
|  | self.assertEqual(named_tensor.diagonal(1, 3).names, ['A', 'C', None]) | 
|  |  | 
|  | self.assertEqual(named_tensor.diagonal(outdim='E', dim1='B', dim2='D').names, | 
|  | ['A', 'C', 'E']) | 
|  |  | 
|  | def test_max_pooling(self): | 
|  | def check_tuple_return(op, inputs, expected_names): | 
|  | values, indices = op(*inputs) | 
|  | self.assertEqual(values.names, expected_names) | 
|  | self.assertEqual(indices.names, expected_names) | 
|  |  | 
|  | for device in torch.testing.get_all_device_types(): | 
|  |  | 
|  | named_tensor_1d = torch.zeros(2, 3, 5, device=device, names=list('ABC')) | 
|  | named_tensor_2d = torch.zeros(2, 3, 5, 7, device=device, names=list('ABCD')) | 
|  | named_tensor_3d = torch.zeros(2, 3, 5, 7, 9, device=device, names=list('ABCDE')) | 
|  |  | 
|  | self.assertEqual(F.max_pool1d(named_tensor_1d, 2).names, named_tensor_1d.names) | 
|  | self.assertEqual(F.max_pool2d(named_tensor_2d, [2, 2]).names, named_tensor_2d.names) | 
|  | self.assertEqual(F.max_pool3d(named_tensor_3d, [2, 2, 2]).names, named_tensor_3d.names) | 
|  |  | 
|  | check_tuple_return(F.max_pool1d_with_indices, [named_tensor_1d, 2], named_tensor_1d.names) | 
|  | check_tuple_return(F.max_pool2d_with_indices, [named_tensor_2d, [2, 2]], named_tensor_2d.names) | 
|  | check_tuple_return(F.max_pool3d_with_indices, [named_tensor_3d, [2, 2, 2]], named_tensor_3d.names) | 
|  |  | 
|  | def test_no_save_support(self): | 
|  | named_tensor = torch.zeros(2, 3, names=('N', 'C')) | 
|  | buf = io.BytesIO() | 
|  | with self.assertRaisesRegex(RuntimeError, "NYI"): | 
|  | torch.save(named_tensor, buf) | 
|  |  | 
|  | def test_no_pickle_support(self): | 
|  | named_tensor = torch.zeros(2, 3, names=('N', 'C')) | 
|  | with self.assertRaisesRegex(RuntimeError, "NYI"): | 
|  | serialized = pickle.dumps(named_tensor) | 
|  |  | 
|  | def test_no_multiprocessing_support(self): | 
|  | named_tensor = torch.zeros(2, 3, names=('N', 'C')) | 
|  | buf = io.BytesIO() | 
|  | with self.assertRaisesRegex(RuntimeError, "NYI"): | 
|  | ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(named_tensor) | 
|  |  | 
|  | def test_big_tensor_repr_has_names(self): | 
|  | def check_repr(named_tensor): | 
|  | unnamed_tensor = named_tensor.rename(None) | 
|  | names_tag = 'names={}'.format(named_tensor.names) | 
|  | self.assertIn(names_tag, repr(named_tensor)) | 
|  |  | 
|  | check_repr(torch.randn(128, 3, 64, 64, names=('N', 'C', 'H', 'W'))) | 
|  |  | 
|  | def test_noncontig_contiguous(self): | 
|  | # This type of contiguous is special-cased and therefore needs its own test | 
|  | for device in torch.testing.get_all_device_types(): | 
|  | x = torch.randn(2, 3, device=device).t().rename_('N', 'C') | 
|  | self.assertEqual(x.contiguous().names, ('N', 'C')) | 
|  |  | 
|  | def test_copy_transpose(self): | 
|  | # This type of copy is special-cased and therefore needs its own test | 
|  | def _test(self_names, other_names, expected_names): | 
|  | x = torch.empty(2, 5, names=self_names) | 
|  | y = torch.empty(5, 2).t().rename_(*other_names) | 
|  | x.copy_(y) | 
|  | self.assertEqual(x.names, expected_names) | 
|  |  | 
|  | _test(('N', 'C'), ('N', 'C'), ('N', 'C')) | 
|  | _test(None, ('N', 'C'), ('N', 'C')) | 
|  |  | 
|  | def test_rename_(self): | 
|  | tensor = torch.empty(1, 1, names=('N', 'C')) | 
|  | self.assertEqual(tensor.rename_(None).names, (None, None)) | 
|  | self.assertEqual(tensor.rename_('H', 'W').names, ('H', 'W')) | 
|  | with self.assertRaisesRegex(RuntimeError, 'Number of names'): | 
|  | tensor.rename_('N', 'C', 'W') | 
|  | with self.assertRaisesRegex(RuntimeError, 'duplicate names'): | 
|  | tensor.rename_('N', 'N') | 
|  |  | 
|  | def test_rename(self): | 
|  | tensor = torch.empty(1, 1, names=('N', 'C')) | 
|  |  | 
|  | self.assertEqual(tensor.rename(None).names, (None, None)) | 
|  | self.assertEqual(tensor.rename('H', 'W').names, ('H', 'W')) | 
|  |  | 
|  | # Check that we didn't modify tensor.names | 
|  | self.assertEqual(tensor.names, ('N', 'C')) | 
|  |  | 
|  | with self.assertRaisesRegex(RuntimeError, 'Number of names'): | 
|  | tensor.rename('N', 'C', 'W') | 
|  | with self.assertRaisesRegex(RuntimeError, 'duplicate names'): | 
|  | tensor.rename('N', 'N') | 
|  |  | 
|  | with self.assertRaisesRegex(RuntimeError, 'either positional args or keyword args'): | 
|  | tensor.rename(None, N='batch') | 
|  |  | 
|  | # rename returns a view on the tensor | 
|  | self.assertEqual(tensor.rename('H', 'W').data_ptr(), tensor.data_ptr()) | 
|  | self.assertEqual(tensor.rename(None).data_ptr(), tensor.data_ptr()) | 
|  |  | 
|  | def test_rename_globber(self): | 
|  | scalar = torch.randn([]) | 
|  | unnamed_tensor = torch.empty(1, 1, 1, 1) | 
|  | named_tensor = torch.empty(1, 1, 1, 1, names=('N', 'C', 'H', 'W')) | 
|  |  | 
|  | self.assertEqual(scalar.rename(None).names, []) | 
|  | self.assertEqual(scalar.rename('...').names, []) | 
|  |  | 
|  | # Check that it works with unnamed tensors | 
|  | self.assertEqual(unnamed_tensor.rename('...').names, unnamed_tensor.names) | 
|  | self.assertEqual(unnamed_tensor.rename('...', 'H', 'W').names, | 
|  | [None, None, 'H', 'W']) | 
|  | self.assertEqual(unnamed_tensor.rename('N', '...', 'W').names, | 
|  | ['N', None, None, 'W']) | 
|  | self.assertEqual(unnamed_tensor.rename('N', 'C', '...').names, | 
|  | ['N', 'C', None, None]) | 
|  |  | 
|  | # Check that it works with named tensors | 
|  | self.assertEqual(named_tensor.rename('...').names, named_tensor.names) | 
|  | self.assertEqual(named_tensor.rename('...', 'width').names, | 
|  | ['N', 'C', 'H', 'width']) | 
|  | self.assertEqual(named_tensor.rename('batch', 'channels', '...', 'width').names, | 
|  | ['batch', 'channels', 'H', 'width']) | 
|  | self.assertEqual(named_tensor.rename('batch', '...').names, | 
|  | ['batch', 'C', 'H', 'W']) | 
|  |  | 
|  | # Test empty glob | 
|  | self.assertEqual(unnamed_tensor.rename('...', None, None, None, None).names, | 
|  | [None, None, None, None]) | 
|  | self.assertEqual(named_tensor.rename('N', 'C', 'H', '...', 'W').names, | 
|  | ['N', 'C', 'H', 'W']) | 
|  |  | 
|  | # Multiple globs throw | 
|  | with self.assertRaisesRegex(RuntimeError, 'More than one '): | 
|  | named_tensor.rename('...', 'channels', '...') | 
|  |  | 
|  | def test_rename_rename_map(self): | 
|  | scalar = torch.randn([]) | 
|  | unnamed_tensor = torch.empty(1, 1, 1, 1) | 
|  | named_tensor = torch.empty(1, 1, 1, 1, names=('N', 'C', 'H', 'W')) | 
|  |  | 
|  | with self.assertRaisesRegex(RuntimeError, "dim 'N' does not exist"): | 
|  | scalar.rename(N='batch') | 
|  | with self.assertRaisesRegex(RuntimeError, "dim 'N' does not exist"): | 
|  | unnamed_tensor.rename(N='batch') | 
|  | with self.assertRaisesRegex(RuntimeError, "dim 'B' does not exist"): | 
|  | named_tensor.rename(B='batch') | 
|  | with self.assertRaisesRegex(RuntimeError, "dim 'B' does not exist"): | 
|  | named_tensor.rename(H='height', B='batch') | 
|  |  | 
|  | self.assertEqual(named_tensor.rename(N='batch').data_ptr(), | 
|  | named_tensor.data_ptr()) | 
|  | self.assertEqual(named_tensor.rename(N='batch').names, | 
|  | ['batch', 'C', 'H', 'W']) | 
|  | self.assertEqual(named_tensor.rename(N='batch', H='height').names, | 
|  | ['batch', 'C', 'height', 'W']) | 
|  |  | 
|  | def test_set_names_property(self): | 
|  | tensor = torch.empty(1, 1, names=('N', 'C')) | 
|  |  | 
|  | tensor.names = None | 
|  | self.assertEqual(tensor.names, (None, None)) | 
|  |  | 
|  | tensor.names = ('N', 'W') | 
|  | self.assertEqual(tensor.names, ('N', 'W')) | 
|  |  | 
|  | with self.assertRaisesRegex(RuntimeError, 'Number of names'): | 
|  | tensor.names = ['N', 'C', 'W'] | 
|  | with self.assertRaisesRegex(RuntimeError, 'duplicate names'): | 
|  | tensor.names = ['N', 'N'] | 
|  |  | 
|  | def test_factory_edge_cases(self): | 
|  | for device in torch.testing.get_all_device_types(): | 
|  | self._test_factory(torch.empty, device) | 
|  |  | 
|  | def test_factory_coverage(self): | 
|  | def _test(factory, device): | 
|  | names = ('N', 'T', 'D') | 
|  |  | 
|  | torch.manual_seed(0) | 
|  | result = factory(1, 2, 3, names=names, device=device) | 
|  |  | 
|  | torch.manual_seed(0) | 
|  | expected = factory(1, 2, 3, device=device).rename_(*names) | 
|  |  | 
|  | self.assertTensorDataAndNamesEqual(result, expected) | 
|  |  | 
|  | supported = [ | 
|  | torch.ones, | 
|  | torch.rand, | 
|  | torch.randn, | 
|  | torch.zeros, | 
|  | ] | 
|  |  | 
|  | for op, device in itertools.product(supported, torch.testing.get_all_device_types()): | 
|  | _test(op, device) | 
|  |  | 
|  | # Test torch.full | 
|  | for device in torch.testing.get_all_device_types(): | 
|  | names = ('N', 'T', 'D') | 
|  | result = torch.full([1, 2, 3], 2., names=names, device=device) | 
|  | expected = torch.full([1, 2, 3], 2., device=device).rename_(*names) | 
|  | self.assertTensorDataAndNamesEqual(result, expected) | 
|  |  | 
|  | def test_tensor_from_lists(self): | 
|  | names = ('N', 'C') | 
|  | tensor = torch.tensor([[1]], names=names) | 
|  | self.assertEqual(tensor.names, names) | 
|  |  | 
|  | names = ('N',) | 
|  | tensor = torch.tensor([1], names=names) | 
|  | self.assertEqual(tensor.names, names) | 
|  |  | 
|  | with self.assertRaisesRegex(RuntimeError, 'Number of names'): | 
|  | names = ('N', 'C') | 
|  | tensor = torch.tensor([1], names=names) | 
|  |  | 
|  | @unittest.skipIf(not TEST_NUMPY, "no numpy") | 
|  | def test_tensor_from_numpy(self): | 
|  | import numpy as np | 
|  | arr = np.array([[1]]) | 
|  | names = ('N', 'C') | 
|  | tensor = torch.tensor([[1]], names=names) | 
|  | self.assertEqual(tensor.names, names) | 
|  |  | 
|  | def test_tensor_from_tensor(self): | 
|  | x = torch.randn(1, 1) | 
|  | names = ('N', 'C') | 
|  | tensor = torch.tensor(x, names=names) | 
|  | self.assertEqual(tensor.names, names) | 
|  |  | 
|  | def test_tensor_from_named_tensor(self): | 
|  | x = torch.randn(1, 1, names=('N', 'D')) | 
|  | tensor = torch.tensor(x) | 
|  | self.assertEqual(tensor.names, ('N', 'D')) | 
|  |  | 
|  | # there's no way to distinguish between names=None and not passing in names. | 
|  | # If the user passes in names=None they are asking for trouble. | 
|  | x = torch.randn(1, 1, names=('N', 'D')) | 
|  | tensor = torch.tensor(x, names=None) | 
|  | self.assertEqual(tensor.names, ('N', 'D')) | 
|  |  | 
|  | x = torch.randn(1, 1, names=('N', 'D')) | 
|  | with self.assertRaisesRegex(RuntimeError, "Name mismatch"): | 
|  | tensor = torch.tensor(x, names=('N', 'C')) | 
|  |  | 
|  | def test_size(self): | 
|  | t = torch.empty(2, 3, 5, names=('N', None, 'C')) | 
|  | self.assertEqual(t.size('N'), 2) | 
|  | self.assertEqual(t.size('C'), 5) | 
|  | with self.assertRaisesRegex(RuntimeError, 'Please look up dimensions by name*'): | 
|  | t.size(None) | 
|  | with self.assertRaisesRegex(RuntimeError, 'Name \'channels\' not found in '): | 
|  | t.size('channels') | 
|  | with self.assertRaisesRegex(RuntimeError, 'Name \'N\' not found in '): | 
|  | torch.empty(2, 3, 4).size('N') | 
|  |  | 
|  | def test_stride(self): | 
|  | t = torch.empty(2, 3, 5, names=('N', None, 'C')) | 
|  | self.assertEqual(t.stride('N'), 3 * 5) | 
|  | self.assertEqual(t.stride('C'), 1) | 
|  | with self.assertRaisesRegex(RuntimeError, 'Please look up dimensions by name'): | 
|  | t.stride(None) | 
|  | with self.assertRaisesRegex(RuntimeError, 'Name \'channels\' not found in '): | 
|  | t.stride('channels') | 
|  | with self.assertRaisesRegex(RuntimeError, 'Name \'N\' not found in '): | 
|  | torch.empty(2, 3, 4).stride('N') | 
|  |  | 
|  | def test_transpose_variants(self): | 
|  | t = torch.randn(2, 3, 5, 7, names=('N', 'C', 'H', 'W')) | 
|  | self.assertEqual(t.transpose('N', 'C').names, ['C', 'N', 'H', 'W']) | 
|  | self.assertEqual(t.transpose(1, 3).names, ['N', 'W', 'H', 'C']) | 
|  |  | 
|  | t = torch.randn(2, 3, names=('N', 'C')) | 
|  | self.assertEqual(t.t().names, ['C', 'N']) | 
|  |  | 
|  | def test_resize(self): | 
|  | for device in torch.testing.get_all_device_types(): | 
|  | named = torch.randn(2, names=('N',), device=device) | 
|  | named.resize_([2]) | 
|  | self.assertEqual(named.names, ['N']) | 
|  |  | 
|  | with self.assertRaisesRegex(RuntimeError, "Cannot resize named tensor"): | 
|  | named.resize_([3]) | 
|  |  | 
|  | other_named = torch.randn(2, names=('N',), device=device) | 
|  | named.resize_as_(other_named) | 
|  | self.assertEqual(other_named.names, ['N']) | 
|  |  | 
|  | unnamed = torch.randn(2, device=device) | 
|  | with self.assertRaisesRegex( | 
|  | RuntimeError, r'names .* are not the same as the computed output names'): | 
|  | named.resize_as_(unnamed) | 
|  |  | 
|  | unnamed = torch.randn(1, device=device) | 
|  | unnamed.resize_as_(named) | 
|  | self.assertEqual(unnamed.names, ['N']) | 
|  |  | 
|  | def test_cdist(self): | 
|  | for device in torch.testing.get_all_device_types(): | 
|  | tensor = torch.randn(3, 1, 2, 7, names=('M', 'N', 'first_group', 'features'), | 
|  | device=device) | 
|  | other = torch.randn(5, 11, 7, names=('N', 'second_group', 'features'), | 
|  | device=device) | 
|  | result = torch.cdist(tensor, other) | 
|  | self.assertEqual(result.names, ['M', 'N', 'first_group', 'second_group']) | 
|  |  | 
|  | def test_info_smoke(self): | 
|  | # Smoke test for info functions / methods / attributes on named tensors. | 
|  | tensor = torch.empty(1, 1, names=('N', 'D')) | 
|  |  | 
|  | tensor.device | 
|  | tensor.dtype | 
|  | tensor.get_device() | 
|  | tensor.is_complex() | 
|  | tensor.is_floating_point() | 
|  | tensor.is_nonzero() | 
|  | torch.is_same_size(tensor, tensor) | 
|  | torch.is_signed(tensor) | 
|  | tensor.layout | 
|  | tensor.numel() | 
|  | tensor.dim() | 
|  | tensor.element_size() | 
|  | tensor.is_contiguous() | 
|  | tensor.is_cuda | 
|  | tensor.is_leaf | 
|  | tensor.is_pinned() | 
|  | tensor.is_shared() | 
|  | tensor.is_sparse | 
|  | tensor.ndimension() | 
|  | tensor.nelement() | 
|  | tensor.shape | 
|  | tensor.size() | 
|  | tensor.size(1) | 
|  | tensor.storage() | 
|  | tensor.storage_offset() | 
|  | tensor.storage_type() | 
|  | tensor.stride() | 
|  | tensor.stride(1) | 
|  | tensor.data | 
|  | tensor.data_ptr() | 
|  | tensor.ndim | 
|  | tensor.item() | 
|  | tensor.type() | 
|  | tensor.is_shared() | 
|  | tensor.is_signed() | 
|  |  | 
|  | def test_autograd_smoke(self): | 
|  | x = torch.randn(3, 3, names=('N', 'D'), requires_grad=True) | 
|  |  | 
|  | y = x.clone() | 
|  | y.retain_grad() | 
|  | y.register_hook(lambda x: x) | 
|  |  | 
|  | y.sum().backward() | 
|  |  | 
|  | # autograd related attributes | 
|  | tensor = torch.empty(1, 1, names=('N', 'D'), requires_grad=True) | 
|  | tensor = tensor.relu() | 
|  | tensor.output_nr | 
|  | tensor.grad_fn | 
|  | tensor.requires_grad | 
|  |  | 
|  | def test_split_fns_propagates_names(self): | 
|  | fns = [ | 
|  | lambda x: x.split(1, 0), | 
|  | lambda x: x.split([1, 1], 1), | 
|  | lambda x: x.chunk(2, 0), | 
|  | ] | 
|  |  | 
|  | for device in torch.testing.get_all_device_types(): | 
|  | orig_tensor = torch.empty(2, 2, names=('N', 'D'), device=device) | 
|  | for fn in fns: | 
|  | splits = fn(orig_tensor) | 
|  | for split in splits: | 
|  | self.assertEqual(split.names, orig_tensor.names) | 
|  |  | 
|  | def test_any_all(self): | 
|  | for device in torch.testing.get_all_device_types(): | 
|  | x = torch.zeros(3, dtype=torch.bool, device=device, names=('C',)) | 
|  | self.assertEqual(x.any().names, []) | 
|  | self.assertEqual(x.all().names, []) | 
|  |  | 
|  | def test_addcmul_addcdiv(self): | 
|  | for device in torch.testing.get_all_device_types(): | 
|  | names = ['N'] | 
|  | a = torch.rand(3, device=device, names=names) | 
|  | b = torch.rand(3, device=device, names=names) | 
|  | # avoid division by 0 | 
|  | c = torch.rand(3, device=device, names=names).clamp_min_(0.1) | 
|  | out = torch.randn(3, device=device, names=names) | 
|  |  | 
|  | self.assertEqual(torch.addcmul(a, b, c).names, names) | 
|  | self.assertEqual(torch.addcmul(a, b, c, out=out).names, names) | 
|  | self.assertEqual(a.addcmul_(b, c).names, names) | 
|  |  | 
|  | self.assertEqual(torch.addcdiv(a, b, c).names, names) | 
|  | self.assertEqual(torch.addcdiv(a, b, c, out=out).names, names) | 
|  | self.assertEqual(a.addcdiv_(b, c).names, names) | 
|  |  | 
    def test_binary_ops(self):
        # Name-inference checks for binary ops. Every op listed in `tests` below
        # is run through three checks: basic name matching, None-as-wildcard
        # unification, and mixing a named operand with unnamed operands.
        def test_basic(op):
            a = torch.empty(2, 3, names=('N', 'C'))
            b = torch.empty(3, 2, names=('C', 'N'))
            c = torch.empty(3, names=('C',))
            d = torch.empty(5, names=('W',))

            self.assertEqual(op(a, a).names, ('N', 'C'))
            # c lines up with a's trailing 'C' dim.
            self.assertEqual(op(a, c).names, ('N', 'C'))

            # Names are compared position by position, so differing names raise,
            # even when (as for a and b) the same names appear in another order.
            with self.assertRaisesRegex(RuntimeError, "do not match"):
                op(a, d)
            with self.assertRaisesRegex(RuntimeError, "do not match"):
                op(a, b)

        def test_wildcard(op):
            # A None name unifies with any name.
            a = torch.empty(2, 3, names=('N', 'C'))
            c = torch.empty(2, 3, names=(None, 'C'))
            self.assertEqual(op(a, c).names, ('N', 'C'))

            # A fully unnamed operand acts like all-None names here.
            b = torch.empty(2, 3)
            self.assertEqual(op(a, b).names, ('N', 'C'))

            # Unifying ('C', None) with (None, 'C') would put 'C' in two
            # positions; the op must raise instead.
            d = torch.empty(2, 3, names=('C', None))
            with self.assertRaisesRegex(RuntimeError, "Misaligned"):
                op(d, c)

        def test_mixed_unnamed_named(op, is_inplace):
            # Pair a 2-d named tensor with unnamed tensors of rank 1, 2 and 3,
            # in both argument orders. Fresh tensors are created on each call,
            # so in-place variants mutating them cannot leak into other ops.
            named2 = torch.randn(1, 1, names=('N', 'C'))
            unnamed1 = torch.randn(1)
            unnamed2 = torch.randn(1, 1)
            unnamed3 = torch.randn(1, 1, 1)

            def compute_expected_names(tensor, other):
                # Exactly one operand is named. The result keeps its names; any
                # extra leading dims from a higher-rank unnamed operand get None.
                assert tensor.has_names() ^ other.has_names()
                named = tensor if tensor.has_names() else other
                unnamed = other if tensor.has_names() else tensor
                unnamed_dim = unnamed.dim()
                if unnamed_dim > named.dim():
                    return [None] * (unnamed_dim - named.dim()) + list(named.names)
                else:
                    return named.names

            inputs = itertools.chain(
                itertools.product([named2], [unnamed1, unnamed2, unnamed3]),
                itertools.product([unnamed1, unnamed2, unnamed3], [named2]),
            )
            if is_inplace:
                # In-place ops have the constraint that they must not change shape.
                inputs = [(a, b) for (a, b) in inputs if a.dim() >= b.dim()]

            for tensor, other in inputs:
                # Expected names are computed before the op runs, since in-place
                # variants may modify `tensor`.
                expected_names = compute_expected_names(tensor, other)
                self.assertEqual(op(tensor, other).names, expected_names)

        # The builders below each return a one-element list of
        # Function(name, lambd) so results can be concatenated and flattened.
        # `name`, `args` and `kwargs` are parameters of each builder call, so
        # every lambda captures its own values.
        def method(name, *args, **kwargs):
            # a.<name>(b, ...); covers in-place methods when `name` ends in '_'.
            return [Function(name, lambda a, b: getattr(a, name)(b, *args, **kwargs))]

        def function(name, *args, **kwargs):
            # torch.<name>(a, b, ...)
            return [Function(name, lambda a, b: getattr(torch, name)(a, b, *args, **kwargs))]

        def out_function(name, *args, **kwargs):
            # torch.<name>(a, b, ..., out=result). `result` starts as an empty,
            # unnamed tensor and its names afterwards are what gets checked.
            # (This local `out_fn` shadows the module-level out_fn helper.)
            out_fn = getattr(torch, name)

            def fn(a, b):
                result = torch.empty([0], dtype=a.dtype, device=a.device)
                out_fn(a, b, *args, out=result, **kwargs)
                return result

            return [Function(name, fn)]

        def fn_method_and_inplace(name, *args, **kwargs):
            # Method, in-place method ('<name>_') and out= function variants.
            return (
                method(name, *args, **kwargs) +
                method(name + '_', *args, **kwargs) +
                out_function(name, *args, **kwargs)
            )

        tests = [
            fn_method_and_inplace('add'),
            fn_method_and_inplace('div'),
            fn_method_and_inplace('mul'),
            fn_method_and_inplace('sub'),
            fn_method_and_inplace('pow'),
            fn_method_and_inplace('atan2'),
            method('copy_'),
            function('floor_divide'),
            function('true_divide'),
        ]
        tests = flatten(tests)

        for name, op in tests:
            test_basic(op)
            test_wildcard(op)
            # A trailing '_' marks an in-place op (e.g. 'add_', 'copy_').
            test_mixed_unnamed_named(op, is_inplace=name.endswith('_'))
|  |  | 
|  | def test_logical_ops(self): | 
|  | # Implemented via TensorIterator, so just check that each version | 
|  | # (out-of-place, inplace, out=) propagates names. | 
|  | def zeros(*args, **kwargs): | 
|  | return torch.zeros(*args, dtype=torch.bool, **kwargs) | 
|  |  | 
|  | for op in ('logical_xor', 'logical_and', 'logical_or'): | 
|  | self._test_name_inference( | 
|  | getattr(torch, op), | 
|  | (create('N:2,C:3', zeros), create('N:2,C:3', zeros)), | 
|  | expected_names=['N', 'C']) | 
|  |  | 
|  | self._test_name_inference( | 
|  | getattr(Tensor, op + '_'), | 
|  | (create('N:2,C:3', zeros), create('N:2,C:3', zeros)), | 
|  | expected_names=['N', 'C']) | 
|  |  | 
|  | self._test_name_inference( | 
|  | lambda out, x, y: getattr(torch, op)(x, y, out=out), | 
|  | (create('0', zeros), create('N:2,C:3', zeros), create('N:2,C:3', zeros)), | 
|  | expected_names=['N', 'C']) | 
|  |  | 
|  | def test_pow_special(self): | 
|  | # There are a few pow cases that don't go through TensorIterator. | 
|  | # Test them here. | 
|  | for device in torch.testing.get_all_device_types(): | 
|  | named = torch.randn(2, 3, names=('N', 'C'), device=device) | 
|  | unnamed = torch.randn([0], device=device) | 
|  |  | 
|  | result = torch.pow(named, 0, out=unnamed.clone()) | 
|  | self.assertEqual(result.names, named.names) | 
|  |  | 
|  | result = torch.pow(named, 1, out=unnamed.clone()) | 
|  | self.assertEqual(result.names, named.names) | 
|  |  | 
|  | result = torch.pow(1, named, out=unnamed.clone()) | 
|  | self.assertEqual(result.names, named.names) | 
|  |  | 
|  | def test_out_fn_semantics(self): | 
|  | out_fn = torch.abs | 
|  | unnamed_tensor = torch.randn(3, 2) | 
|  | none_named_tensor = torch.randn(3, 2, names=(None, None)) | 
|  | named_tensor = torch.randn(3, 2, names=('N', 'C')) | 
|  | partially_named_tensor = torch.randn(3, 2, names=('N', None)) | 
|  |  | 
|  | with self.assertRaisesRegex(RuntimeError, "Name mismatch"): | 
|  | out_fn(partially_named_tensor, out=named_tensor) | 
|  | with self.assertRaisesRegex(RuntimeError, "Name mismatch"): | 
|  | out_fn(named_tensor, out=partially_named_tensor) | 
|  | with self.assertRaisesRegex(RuntimeError, "Name mismatch"): | 
|  | out_fn(none_named_tensor, out=named_tensor) | 
|  | with self.assertRaisesRegex(RuntimeError, "Name mismatch"): | 
|  | out_fn(unnamed_tensor, out=named_tensor) | 
|  |  | 
|  | output = torch.randn(3, 2) | 
|  | out_fn(unnamed_tensor, out=output) | 
|  | self.assertFalse(output.has_names()) | 
|  |  | 
|  | output = torch.randn(3, 2, names=(None, None)) | 
|  | out_fn(named_tensor, out=output) | 
|  | self.assertEqual(output.names, named_tensor.names) | 
|  |  | 
|  | output = torch.randn(3, 2) | 
|  | out_fn(named_tensor, out=output) | 
|  | self.assertEqual(output.names, named_tensor.names) | 
|  |  | 
|  | output = torch.randn(3, 2, names=(None, None)) | 
|  | out_fn(unnamed_tensor, out=output) | 
|  | self.assertFalse(output.has_names()) | 
|  |  | 
|  | def test_unary_propagate_names_fns(self): | 
|  | def _test(testcase, names=('N', 'D'), device='cpu'): | 
|  | sizes = [2] * len(names) | 
|  | tensor = torch.empty(sizes, names=names, device=device) | 
|  | try: | 
|  | out = testcase.lambd(tensor) | 
|  | except RuntimeError as err: | 
|  | # Get a better error message by catching the error and asserting. | 
|  | raise RuntimeError('{}: {}'.format(testcase.name, err)) from err | 
|  | self.assertEqual(out.names, tensor.names, | 
|  | msg=testcase.name) | 
|  |  | 
|  | def fn(name, *args, **kwargs): | 
|  | return [Function(name, lambda t: getattr(torch, name)(t, *args, **kwargs))] | 
|  |  | 
|  | def method(name, *args, **kwargs): | 
|  | return [Function(name, lambda t: getattr(t, name)(*args, **kwargs))] | 
|  |  | 
|  | def out_function(name, *args, **kwargs): | 
|  | out_fn = getattr(torch, name) | 
|  |  | 
|  | def fn(tensor): | 
|  | result = torch.empty([0], dtype=tensor.dtype, device=tensor.device) | 
|  | out_fn(tensor, *args, out=result, **kwargs) | 
|  | return result | 
|  |  | 
|  | return [Function(name + '_out', fn)] | 
|  |  | 
|  | def fn_method_and_inplace(name, *args, **kwargs): | 
|  | return ( | 
|  | method(name, *args, **kwargs) + | 
|  | method(name + '_', *args, **kwargs) + | 
|  | out_function(name, *args, **kwargs) | 
|  | ) | 
|  |  | 
|  | # All of these operate on 2x2 tensors. | 
|  | tests = [ | 
|  | # unary pointwise | 
|  | fn_method_and_inplace('abs'), | 
|  | fn_method_and_inplace('acos'), | 
|  | fn_method_and_inplace('asin'), | 
|  | fn_method_and_inplace('atan'), | 
|  | fn_method_and_inplace('ceil'), | 
|  | fn_method_and_inplace('clamp', -1, 1), | 
|  | fn_method_and_inplace('clamp_min', -2), | 
|  | fn_method_and_inplace('clamp_max', 2), | 
|  | method('cauchy_'), | 
|  | method('clone'), | 
|  | method('contiguous'), | 
|  | fn_method_and_inplace('cos'), | 
|  | fn_method_and_inplace('cosh'), | 
|  | fn_method_and_inplace('digamma'), | 
|  | fn_method_and_inplace('erf'), | 
|  | fn_method_and_inplace('erfc'), | 
|  | fn_method_and_inplace('erfinv'), | 
|  | fn_method_and_inplace('exp'), | 
|  | fn_method_and_inplace('expm1'), | 
|  | method('exponential_'), | 
|  | fn_method_and_inplace('floor'), | 
|  | fn_method_and_inplace('frac'), | 
|  | method('geometric_', p=0.5), | 
|  | fn_method_and_inplace('lgamma'), | 
|  | fn_method_and_inplace('log'), | 
|  | fn_method_and_inplace('log10'), | 
|  | fn_method_and_inplace('log1p'), | 
|  | fn_method_and_inplace('log2'), | 
|  | method('log_normal_'), | 
|  | fn_method_and_inplace('neg'), | 
|  | method('normal_'), | 
|  | [Function('polygamma', lambda t: torch.polygamma(1, t))], | 
|  | method('polygamma_', 1), | 
|  | fn_method_and_inplace('reciprocal'), | 
|  | method('random_', 0, 1), | 
|  | method('random_', 1), | 
|  | method('random_'), | 
|  | method('relu_'), | 
|  | method('requires_grad_'), | 
|  | method('relu'), | 
|  | fn_method_and_inplace('round'), | 
|  | fn_method_and_inplace('rsqrt'), | 
|  | fn_method_and_inplace('sigmoid'), | 
|  | fn_method_and_inplace('sign'), | 
|  | fn_method_and_inplace('sin'), | 
|  | fn_method_and_inplace('sinh'), | 
|  | fn_method_and_inplace('sqrt'), | 
|  | fn_method_and_inplace('tan'), | 
|  | fn_method_and_inplace('tanh'), | 
|  | fn('threshold', 0, 1), | 
|  | fn('threshold_', 0, 1), | 
|  | out_function('threshold', 0, 1), | 
|  | fn_method_and_inplace('trunc'), | 
|  | method('uniform_'), | 
|  | method('zero_'), | 
|  | method('fill_', 1), | 
|  | method('fill_', torch.tensor(3.14)), | 
|  |  | 
|  | # conversions | 
|  | method('to', dtype=torch.long), | 
|  | method('to', device='cpu'), | 
|  | method('to', torch.empty([])), | 
|  | method('bool'), | 
|  | method('byte'), | 
|  | method('char'), | 
|  | method('cpu'), | 
|  | method('double'), | 
|  | method('float'), | 
|  | method('long'), | 
|  | method('half'), | 
|  | method('int'), | 
|  | method('short'), | 
|  | method('type', dtype=torch.long), | 
|  |  | 
|  | # cumsum and cumprod | 
|  | fn('cumsum', 0), | 
|  | fn('cumsum', 'D'), | 
|  | out_function('cumsum', 'D'), | 
|  | fn('cumprod', 0), | 
|  | fn('cumprod', 'D'), | 
|  | out_function('cumprod', 'D'), | 
|  |  | 
|  | # views | 
|  | method('narrow', 0, 0, 1), | 
|  |  | 
|  | # creation functions | 
|  | fn('empty_like'), | 
|  | fn('zeros_like'), | 
|  | fn('ones_like'), | 
|  | fn('full_like', 3.14), | 
|  | fn('rand_like'), | 
|  | fn('randn_like'), | 
|  |  | 
|  | # bernoulli variants | 
|  | method('bernoulli_', 0.5), | 
|  | method('bernoulli_', torch.tensor(0.5)), | 
|  |  | 
|  | method('softmax', dim=1), | 
|  | method('softmax', dim='D'), | 
|  | method('log_softmax', dim=1), | 
|  | method('log_softmax', dim='D'), | 
|  |  | 
|  | [Function('F.dropout(inplace)', lambda t: F.dropout(t, p=0.5, inplace=True))], | 
|  | [Function('F.dropout(outplace)', lambda t: F.dropout(t, p=0.5, inplace=False))], | 
|  | ] | 
|  | tests = flatten(tests) | 
|  |  | 
|  | for testcase, device in itertools.product(tests, torch.testing.get_all_device_types()): | 
|  | _test(testcase, device=device) | 
|  |  | 
|  | def test_cummax_cummin(self): | 
|  | def test_ops(op): | 
|  | for device in torch.testing.get_all_device_types(): | 
|  | names = ('N', 'D') | 
|  | tensor = torch.rand(2, 3, names=names) | 
|  | result = op(tensor, 0) | 
|  | self.assertEqual(result[0].names, names) | 
|  | self.assertEqual(result[1].names, names) | 
|  | test_ops(torch.cummax) | 
|  | test_ops(torch.cummin) | 
|  |  | 
|  | def test_logcumsumexp(self): | 
|  | for device in torch.testing.get_all_device_types(): | 
|  | names = ('N', 'D') | 
|  | tensor = torch.rand(2, 3, names=names) | 
|  | result = torch.logcumsumexp(tensor, 'D') | 
|  | self.assertEqual(result.names, names) | 
|  |  | 
|  | def test_bitwise_not(self): | 
|  | for device in torch.testing.get_all_device_types(): | 
|  | names = ('N', 'D') | 
|  | tensor = torch.zeros(2, 3, names=names, dtype=torch.bool) | 
|  | result = torch.empty(0, dtype=torch.bool) | 
|  |  | 
|  | self.assertEqual(tensor.bitwise_not().names, names) | 
|  | self.assertEqual(torch.bitwise_not(tensor, out=result).names, names) | 
|  | self.assertEqual(tensor.bitwise_not_().names, names) | 
|  |  | 
|  | def test_logical_not(self): | 
|  | for device in torch.testing.get_all_device_types(): | 
|  | names = ('N', 'D') | 
|  | tensor = torch.zeros(2, 3, names=names, dtype=torch.bool) | 
|  | result = torch.empty(0, dtype=torch.bool) | 
|  |  | 
|  | self.assertEqual(tensor.logical_not().names, names) | 
|  | self.assertEqual(torch.logical_not(tensor, out=result).names, names) | 
|  | self.assertEqual(tensor.logical_not_().names, names) | 
|  |  | 
|  | def test_bernoulli(self): | 
|  | for device in torch.testing.get_all_device_types(): | 
|  | names = ('N', 'D') | 
|  | tensor = torch.rand(2, 3, names=names) | 
|  | result = torch.empty(0) | 
|  | self.assertEqual(tensor.bernoulli().names, names) | 
|  |  | 
|  | torch.bernoulli(tensor, out=result) | 
|  | self.assertEqual(result.names, names) | 
|  |  | 
|  | def test_flatten(self): | 
|  | tensor = torch.randn(2, 3, 5, 7, 11, names=('N', 'C', 'D', 'H', 'W')) | 
|  |  | 
|  | # basic | 
|  | out = tensor.flatten('D', 'W', 'features') | 
|  | self.assertEqual(out.names, ['N', 'C', 'features']) | 
|  | self.assertEqual(out.rename(None), tensor.rename(None).view(2, 3, -1)) | 
|  |  | 
|  | # int overload | 
|  | out = tensor.flatten(2, 4, 'features') | 
|  | self.assertEqual(out.names, ['N', 'C', 'features']) | 
|  | self.assertEqual(out.rename(None), tensor.rename(None).view(2, 3, -1)) | 
|  |  | 
|  | # list overload | 
|  | out = tensor.flatten(['D', 'H', 'W'], 'features') | 
|  | self.assertEqual(out.names, ['N', 'C', 'features']) | 
|  | self.assertEqual(out.rename(None), tensor.rename(None).view(2, 3, -1)) | 
|  |  | 
|  | # Non-contiguous flatten: N and H are not "adjacent" in memory. | 
|  | sentences = torch.randn(2, 3, 5, 7, names=('N', 'T', 'H', 'D')) | 
|  | sentences = sentences.transpose('T', 'H') | 
|  | out = sentences.flatten('N', 'H', 'N_H') | 
|  | self.assertEqual(out.names, ['N_H', 'T', 'D']) | 
|  |  | 
|  | with self.assertRaisesRegex(RuntimeError, "Name 'L' not found in"): | 
|  | tensor.flatten(['D', 'L'], 'features') | 
|  |  | 
|  | with self.assertRaisesRegex(RuntimeError, "must be consecutive in"): | 
|  | tensor.flatten(['D', 'W'], 'features') | 
|  |  | 
|  | with self.assertRaisesRegex(RuntimeError, "must be consecutive in"): | 
|  | tensor.flatten(['H', 'D', 'W'], 'features') | 
|  |  | 
|  | def test_unflatten(self): | 
|  | # test args: tensor, int, namedshape | 
|  | self.assertTrue(torch.equal( | 
|  | torch.ones(4).unflatten(0, (('A', 2), ('B', 2))), | 
|  | torch.ones(2, 2, names=('A', 'B')))) | 
|  | self.assertTrue(torch.equal( | 
|  | torch.ones(4).unflatten(0, [('A', 2), ('B', 2)]), | 
|  | torch.ones(2, 2, names=('A', 'B')))) | 
|  | self.assertTrue(torch.equal( | 
|  | torch.ones(4).unflatten(0, (['A', 2], ['B', 2])), | 
|  | torch.ones(2, 2, names=('A', 'B')))) | 
|  | self.assertTrue(torch.equal( | 
|  | torch.ones(4).unflatten(-1, (['A', 2], ['B', 2])), | 
|  | torch.ones(2, 2, names=('A', 'B')))) | 
|  | self.assertTrue(torch.equal( | 
|  | torch.ones(4).unflatten(-1, (['A', -1], ['B', 2])), | 
|  | torch.ones(2, 2, names=('A', 'B')))) | 
|  | self.assertTrue(torch.equal( | 
|  | torch.ones(4).unflatten(-1, (['A', 2], ['B', -1])), | 
|  | torch.ones(2, 2, names=('A', 'B')))) | 
|  | self.assertTrue(torch.equal( | 
|  | torch.ones(2, 10, names=('A', 'B')).unflatten('B', (['B1', -1],)), | 
|  | torch.ones(2, 10, names=('A', 'B1')))) | 
|  | self.assertTrue(torch.equal( | 
|  | torch.ones(2, 3 * 4 * 5 * 6, names=('A', 'B')) | 
|  | .unflatten('B', (['B1', 3], ['B2', 4], ['B3', -1], ['B4', 6])), | 
|  | torch.ones(2, 3, 4, 5, 6, names=('A', 'B1', 'B2', 'B3', 'B4')))) | 
|  | self.assertTrue(torch.equal( | 
|  | torch.ones(2, 0, names=('A', 'B')) | 
|  | .unflatten('B', (['B1', 3], ['B2', -1], ['B3', 4])), | 
|  | torch.ones(2, 3, 0, 4, names=('A', 'B1', 'B2', 'B3')))) | 
|  |  | 
|  | # test args: namedtensor, int, namedshape | 
|  | self.assertTrue(torch.equal( | 
|  | torch.ones(2, 4, names=('A', 'B')).unflatten(1, (('B1', 2), ('B2', 2))), | 
|  | torch.ones(2, 2, 2, names=('A', 'B1', 'B2')))) | 
|  |  | 
|  | # test args: namedtensor, str, namedshape | 
|  | self.assertTrue(torch.equal( | 
|  | torch.ones(2, 4, names=('A', 'B')).unflatten('B', (('B1', 2), ('B2', 2))), | 
|  | torch.ones(2, 2, 2, names=('A', 'B1', 'B2')))) | 
|  |  | 
|  | # test invalid args: namedtensor, str, sizes | 
|  | with self.assertRaisesRegex(TypeError, r"received an invalid combination of arguments"): | 
|  | torch.tensor([1], names=('A',)).unflatten('A', (1, 1)) | 
|  |  | 
|  | # test invalid args: namedtensor, int, sizes | 
|  | with self.assertRaisesRegex(RuntimeError, r"input is a named tensor but no names were given for unflattened sizes"): | 
|  | torch.tensor([1], names=("A",)).unflatten(0, (1, 1)) | 
|  |  | 
|  | with self.assertRaisesRegex(RuntimeError, | 
|  | r"Provided sizes \[3, -1\] don't multiply up to the " | 
|  | r"size of dim 1 \('B': 4\) in Tensor\['A', 'B'\]"): | 
|  | torch.ones(2, 4, names=('A', 'B')).unflatten('B', (('B1', 3), ('B2', -1))) | 
|  |  | 
|  | with self.assertRaisesRegex(RuntimeError, | 
|  | r"the unspecified dimension size -1 can be any value and is ambiguous"): | 
|  | torch.ones(2, 0, names=('A', 'B')).unflatten('B', (('B1', 0), ('B2', -1))) | 
|  |  | 
|  | tensor = torch.randn(7, 2 * 3 * 5, 11, names=('N', 'D', 'K')) | 
|  |  | 
|  | # accepts OrderedDict | 
|  | out = tensor.unflatten('D', OrderedDict((('C', 2), ('H', 3), ('W', 5)))) | 
|  | self.assertEqual(out.names, ('N', 'C', 'H', 'W', 'K')) | 
|  | self.assertEqual(out.shape, (7, 2, 3, 5, 11)) | 
|  |  | 
|  | # Unflatten left-most | 
|  | out = tensor.unflatten('N', (('N', 7), ('H', 1))) | 
|  | self.assertEqual(out.names, ('N', 'H', 'D', 'K')) | 
|  | self.assertEqual(out.shape, (7, 1, 2 * 3 * 5, 11)) | 
|  |  | 
|  | # Unflatten right-most | 
|  | out = tensor.unflatten('K', (('K', 11), ('H', 1))) | 
|  | self.assertEqual(out.names, ('N', 'D', 'K', 'H')) | 
|  | self.assertEqual(out.shape, (7, 2 * 3 * 5, 11, 1)) | 
|  |  | 
|  | with self.assertRaisesRegex(RuntimeError, "don't multiply up to"): | 
|  | tensor.unflatten('D', (('H', 3), ('W', 5))) | 
|  |  | 
|  | with self.assertRaisesRegex(RuntimeError, 'sizes must be non-empty'): | 
|  | tensor.unflatten('D', None) | 
|  |  | 
|  | with self.assertRaisesRegex(RuntimeError, 'non-empty'): | 
|  | tensor.unflatten('D', OrderedDict()) | 
|  |  | 
|  | def test_unsupported_op_error_msg(self): | 
|  | named = torch.randn(3, 3, names=('N', 'C')) | 
|  | with self.assertRaisesRegex( | 
|  | RuntimeError, r"pdist.+is not yet supported with named tensors"): | 
|  | torch.pdist(named) | 
|  | with self.assertRaisesRegex( | 
|  | RuntimeError, r"as_strided_.+is not yet supported with named tensors"): | 
|  | named.as_strided_((3, 3), (3, 1)) | 
|  |  | 
|  | def test_reduction_fns(self): | 
|  | def check_output(output, expected_names): | 
|  | if isinstance(output, torch.Tensor): | 
|  | self.assertEqual(output.names, expected_names) | 
|  | return | 
|  | for out in output: | 
|  | self.assertEqual(out.names, expected_names) | 
|  |  | 
|  | def sum_all_outputs(output): | 
|  | if isinstance(output, torch.Tensor): | 
|  | return output.sum() | 
|  | result = 0 | 
|  | for out in output: | 
|  | result = out + result | 
|  | return result.sum() | 
|  |  | 
|  | def test_simple_reduce(op, device): | 
|  | t = torch.empty(2, 3, 5, names=('N', 'C', 'L'), device=device) | 
|  | check_output(op(t, 1), ['N', 'L']) | 
|  | check_output(op(t, -1), ['N', 'C']) | 
|  | check_output(op(t, 'C'), ['N', 'L']) | 
|  | with self.assertRaisesRegex(RuntimeError, 'Please look up dimensions by name'): | 
|  | op(t, None) | 
|  | with self.assertRaisesRegex(RuntimeError, 'Name \'H\' not found'): | 
|  | op(t, 'H') | 
|  |  | 
|  | def test_autograd_supports_dimname_overload(op, device): | 
|  | t = torch.empty(2, 3, 5, names=('N', 'C', 'L'), device=device, requires_grad=True) | 
|  | sum_all_outputs(op(t, 'C')).backward() | 
|  | self.assertIsNotNone(t.grad) | 
|  |  | 
|  | def test_complete_reduce(op, device): | 
|  | t = torch.empty(2, 3, 5, names=('N', 'C', 'L'), device=device) | 
|  | check_output(op(t), []) | 
|  |  | 
|  | def test_multidim_reduce(op, device): | 
|  | t = torch.empty(2, 3, 5, names=('N', 'C', 'L'), device=device) | 
|  |  | 
|  | check_output(op(t, [1, 2]), ['N']) | 
|  | check_output(op(t, [0, -1]), ['C']) | 
|  | check_output(op(t, ['C', 'L']), ['N']) | 
|  | with self.assertRaisesRegex(RuntimeError, 'Please look up dimensions by name'): | 
|  | op(t, [None, 'C']) | 
|  |  | 
|  | def test_out_variant(op, output_lambda, device): | 
|  | t = torch.empty(2, 3, 5, names=('N', 'C', 'L'), device=device) | 
|  | if output_lambda: | 
|  | out = output_lambda(t) | 
|  | else: | 
|  | out = torch.empty([0], device=device) | 
|  | op(t, 'C', out=out) | 
|  | check_output(out, ['N', 'L']) | 
|  |  | 
|  | def test_keepdim(op, device): | 
|  | t = torch.empty(2, 3, 5, names=('N', 'C', 'L'), device=device) | 
|  | check_output(op(t, 'C', keepdim=True), ['N', 'C', 'L']) | 
|  |  | 
|  | def values_and_indices(t): | 
|  | return (torch.empty([0], device=t.device), | 
|  | torch.empty([0], device=t.device, dtype=torch.long)) | 
|  |  | 
|  | def kthvalue_wrapper(tensor, *args, **kwargs): | 
|  | # Return the 0-th value | 
|  | return torch.kthvalue(tensor, 1, *args, **kwargs) | 
|  |  | 
|  | Case = namedtuple('Case', [ | 
|  | 'op', | 
|  | 'supports_complete_reduce', | 
|  | 'supports_multidim_reduce', | 
|  | 'supports_out_variant', | 
|  | 'supports_keepdim', | 
|  | 'output_lambda', | 
|  | ]) | 
|  |  | 
|  | tests = [ | 
|  | Case(torch.sum, True, True, True, True, None), | 
|  | Case(torch.prod, True, False, True, True, None), | 
|  | Case(torch.mean, True, True, True, True, None), | 
|  | Case(torch.var, True, True, True, True, None), | 
|  | Case(torch.std, True, True, True, True, None), | 
|  | Case(torch.std_mean, True, True, False, True, None), | 
|  | Case(torch.var_mean, True, True, False, True, None), | 
|  | Case(torch.min, True, False, True, True, values_and_indices), | 
|  | Case(torch.max, True, False, True, True, values_and_indices), | 
|  | Case(torch.unbind, False, False, False, False, None), | 
|  | Case(torch.logsumexp, False, True, True, True, None), | 
|  | Case(torch.mode, False, False, True, True, values_and_indices), | 
|  | Case(kthvalue_wrapper, False, False, True, True, values_and_indices), | 
|  | Case(torch.median, True, False, True, True, values_and_indices), | 
|  | Case(torch.nanmedian, True, False, True, True, values_and_indices), | 
|  | ] | 
|  |  | 
|  | for testcase, device in itertools.product(tests, torch.testing.get_all_device_types()): | 
|  | op = testcase.op | 
|  | test_simple_reduce(op, device) | 
|  | test_autograd_supports_dimname_overload(op, device) | 
|  |  | 
|  | if testcase.supports_keepdim: | 
|  | test_keepdim(op, device) | 
|  | if testcase.supports_out_variant: | 
|  | test_out_variant(op, testcase.output_lambda, device) | 
|  | if testcase.supports_complete_reduce: | 
|  | test_complete_reduce(op, device) | 
|  | if testcase.supports_multidim_reduce: | 
|  | test_multidim_reduce(op, device) | 
|  |  | 
|  | def test_masked_select(self): | 
|  | # simple | 
|  | self._test_name_inference( | 
|  | torch.masked_select, | 
|  | (create('N:2,C:3'), (create('2,3') > 0).rename('N', 'C')), | 
|  | expected_names=[None]) | 
|  |  | 
|  | # left broadcast | 
|  | self._test_name_inference( | 
|  | torch.masked_select, | 
|  | (create('C:3'), (create('2,3') > 0).rename('N', 'C')), | 
|  | expected_names=[None]) | 
|  |  | 
|  | # right broadcast | 
|  | self._test_name_inference( | 
|  | torch.masked_select, | 
|  | (create('N:2,C:3'), (create('3') > 0).rename('C')), | 
|  | expected_names=[None]) | 
|  |  | 
|  | # error | 
|  | self._test_name_inference( | 
|  | torch.masked_select, | 
|  | (create('N:2,C:3'), (create('3') > 0).rename('D')), | 
|  | maybe_raises_regex='do not match') | 
|  |  | 
|  | # out= | 
|  | self._test_name_inference( | 
|  | out_fn(torch.masked_select), | 
|  | (create('0'), create('N:2,C:3'), (create('2,3') > 0).rename('N', 'C')), | 
|  | expected_names=[None]) | 
|  |  | 
|  | def test_cat(self): | 
|  | # simple | 
|  | self._test_name_inference( | 
|  | torch.cat, | 
|  | [[create('N:2,C:3'), create('N:2,C:3')]], | 
|  | expected_names=['N', 'C']) | 
|  |  | 
|  | # error: zero dim | 
|  | self._test_name_inference( | 
|  | torch.cat, | 
|  | [[create(''), create('')]], | 
|  | maybe_raises_regex='zero-dim') | 
|  |  | 
|  | # error: names don't match | 
|  | self._test_name_inference( | 
|  | torch.cat, | 
|  | [[create('N:2,C:3'), create('C:3,N:2')]], | 
|  | maybe_raises_regex='do not match') | 
|  |  | 
|  | # error: different number of dims | 
|  | self._test_name_inference( | 
|  | torch.cat, | 
|  | [[create('N:2,C:3'), create('C:3')]], | 
|  | maybe_raises_regex='must have same number of dimensions') | 
|  |  | 
|  | # out= | 
|  | self._test_name_inference( | 
|  | out_fn(torch.cat), | 
|  | [create('0'), [create('N:2,C:3'), create('N:2,C:3')]], | 
|  | expected_names=['N', 'C']) | 
|  |  | 
|  | def test_masked_fill(self): | 
|  | # simple | 
|  | self._test_name_inference( | 
|  | Tensor.masked_fill, | 
|  | (create('N:2,C:3'), (create('2,3') > 0).rename('N', 'C'), 3.14), | 
|  | expected_names=['N', 'C']) | 
|  |  | 
|  | # left broadcast | 
|  | self._test_name_inference( | 
|  | Tensor.masked_fill, | 
|  | (create('C:3'), (create('2,3') > 0).rename('N', 'C'), 3.14), | 
|  | maybe_raises_regex="must be less than or equal to") | 
|  |  | 
|  | # right broadcast | 
|  | self._test_name_inference( | 
|  | Tensor.masked_fill, | 
|  | (create('N:2,C:3'), (create('3') > 0).rename('C'), 3.14), | 
|  | expected_names=['N', 'C']) | 
|  |  | 
|  | # error | 
|  | self._test_name_inference( | 
|  | Tensor.masked_fill, | 
|  | (create('N:2,C:3'), (create('3') > 0).rename('D'), 3.14), | 
|  | maybe_raises_regex='do not match') | 
|  |  | 
|  | # inplace | 
|  | self._test_name_inference( | 
|  | Tensor.masked_fill_, | 
|  | (create('N:2,C:3'), (create('2,3') > 0).rename('N', 'C'), 3.14), | 
|  | expected_names=['N', 'C']) | 
|  |  | 
|  | # inplace, computed names don't match output tensor names | 
|  | self._test_name_inference( | 
|  | Tensor.masked_fill_, | 
|  | (create('N:2,None:3'), (create('2,3') > 0).rename('N', 'C'), 3.14), | 
|  | maybe_raises_regex="not the same as the computed output names") | 
|  |  | 
|  |  | 
    def test_using_seen_interned_string_doesnt_bump_refcount(self):
        # Checks that passing a name the arg parser has already seen does not
        # leave an additional reference to that string behind.
        def see_name():
            # Pass 'N' once before measuring, so the name is already "seen".
            seen_name = 'N'
            pass_name_to_python_arg_parser(seen_name)

        see_name()
        # The test assumes this literal is the same interned object as the
        # 'N' used inside see_name().
        seen_name = 'N'
        old_refcnt = sys.getrefcount(seen_name)

        pass_name_to_python_arg_parser(seen_name)

        new_refcnt = sys.getrefcount(seen_name)
        # Re-using an already-seen name must not change its refcount.
        self.assertEqual(new_refcnt, old_refcnt)
|  |  | 
    def test_using_unseen_interned_string_bumps_refcount_permanently(self):
        # Checks that the first time a never-before-used interned name reaches
        # the arg parser, exactly one extra reference to it remains afterwards.
        # Please don't use this as a name in a different test: the +1 below is
        # only expected the first time the name is seen in this process.
        unseen_name = 'abcdefghi'
        old_refcnt = sys.getrefcount(unseen_name)

        pass_name_to_python_arg_parser(unseen_name)

        new_refcnt = sys.getrefcount(unseen_name)
        self.assertEqual(new_refcnt, old_refcnt + 1)
|  |  | 
    def test_using_unseen_uninterned_string_refcounts(self):
        # Checks that when a runtime-built (uninterned) string is used as a name,
        # the extra reference lands on the interned copy of that name rather
        # than on the caller's string object.
        # Please don't use this as a name in a different test.
        # Strings built at runtime (here via str.join) are not compile-time
        # constants, so they are not interned.
        unseen_name = ''.join(['abc', 'def', 'ghi', 'jkl'])
        interned_unseen_name = 'abcdefghijkl'
        self.assertFalse(unseen_name is interned_unseen_name)

        old_uninterned_refcnt = sys.getrefcount(unseen_name)
        old_interned_refcnt = sys.getrefcount(interned_unseen_name)

        pass_name_to_python_arg_parser(unseen_name)

        new_uninterned_refcnt = sys.getrefcount(unseen_name)
        new_interned_refcnt = sys.getrefcount(interned_unseen_name)

        # Internally, PyTorch should not hold a reference to the uninterned string
        self.assertEqual(new_uninterned_refcnt, old_uninterned_refcnt)

        # Instead, we should hold a new reference to the interned version.
        self.assertEqual(new_interned_refcnt, old_interned_refcnt + 1)
|  |  | 
|  | def _test_select(self, device): | 
|  | x = torch.empty(2, 3, 4, 5, names=('N', 'C', 'H', 'W'), device=device) | 
|  | y = x.select(1, 1) | 
|  | self.assertEqual(y.names, ('N', 'H', 'W')) | 
|  |  | 
|  | y = x.select('C', 1) | 
|  | self.assertEqual(y.names, ('N', 'H', 'W')) | 
|  |  | 
|  | with self.assertRaisesRegex( | 
|  | RuntimeError, 'Please look up dimensions by name'): | 
|  | y = x.select(None, 1) | 
|  |  | 
|  | def test_select(self): | 
|  | self._test_select('cpu') | 
|  |  | 
|  | @unittest.skipIf(not TEST_CUDA, 'no CUDA') | 
|  | def test_select_cuda(self): | 
|  | self._test_select('cuda') | 
|  |  | 
|  | def _test_as_strided(self, device): | 
|  | x = torch.empty(2, 3, 4, 5, names=('N', 'C', 'H', 'W'), device=device) | 
|  | y = x.as_strided([2 * 3 * 4 * 5], [1]) | 
|  | self.assertEqual(y.names, (None,)) | 
|  |  | 
|  | def test_as_strided(self): | 
|  | self._test_as_strided('cpu') | 
|  |  | 
|  | @unittest.skipIf(not TEST_CUDA, 'no CUDA') | 
|  | def test_as_strided_cuda(self): | 
|  | self._test_as_strided('cuda') | 
|  |  | 
|  | def test_no_jit_tracer_support(self): | 
|  | def foo(x): | 
|  | return torch.full(x.shape, 2., names=('N',)) | 
|  |  | 
|  | with self.assertRaisesRegex(RuntimeError, 'not supported with the tracer'): | 
|  | x = torch.randn(3) | 
|  | torch.jit.trace(foo, example_inputs=x) | 
|  |  | 
|  | def bar(x): | 
|  | return x.select('N', 1) | 
|  |  | 
|  | with self.assertRaisesRegex(RuntimeError, 'not supported with the tracer'): | 
|  | x = torch.randn(3) | 
|  | torch.jit.trace(bar, example_inputs=x) | 
|  |  | 
    def test_no_jit_script_support(self):
        # TorchScript is expected to raise an "NYI" error for named tensors,
        # both when one is passed in as an input and when a scripted function
        # returns one.
        @torch.jit.script
        def foo(x):
            return x + 1

        # Named tensor passed as an input to a scripted function.
        with self.assertRaisesRegex(RuntimeError, 'NYI'):
            foo(torch.randn(2, 3, names=('N', 'C')))

        # torch.jit.ignore keeps add_names as a Python call, which lets names
        # be attached to the tensor from inside the scripted function.
        @torch.jit.ignore
        def add_names(x):
            x.names = ('N', 'C')

        @torch.jit.script
        def return_named_tensor(input):
            add_names(input)
            return input

        # Named tensor produced inside the scripted function and returned.
        with self.assertRaisesRegex(RuntimeError, "NYI"):
            return_named_tensor(torch.randn(1, 1))
|  |  | 
|  | def test_align_to(self): | 
|  | # trivial | 
|  | tensor = create('N:3') | 
|  | output = tensor.align_to('N') | 
|  | self.assertEqual(output.names, ['N']) | 
|  | self.assertEqual(output.shape, [3]) | 
|  |  | 
|  | # unsqueeze behavior | 
|  | tensor = create('N:3') | 
|  | output = tensor.align_to('N', 'D') | 
|  | self.assertEqual(output.names, ['N', 'D']) | 
|  | self.assertEqual(output.shape, [3, 1]) | 
|  |  | 
|  | # transpose behavior | 
|  | tensor = create('N:3,C:2') | 
|  | output = tensor.align_to('C', 'N') | 
|  | self.assertEqual(output.names, ['C', 'N']) | 
|  | self.assertEqual(output.shape, [2, 3]) | 
|  |  | 
|  | # unsqueeze / transpose | 
|  | tensor = create('C:2,N:3,H:5') | 
|  | output = tensor.align_to('N', 'H', 'W', 'C') | 
|  | self.assertEqual(output.names, ['N', 'H', 'W', 'C']) | 
|  | self.assertEqual(output.shape, [3, 5, 1, 2]) | 
|  |  | 
|  | # All input dimensions must be named | 
|  | with self.assertRaisesRegex(RuntimeError, "All input dims must be named. Found unnamed dim at index 0"): | 
|  | create('None:2,C:3').align_to('N', 'C') | 
|  |  | 
|  | # not enough names | 
|  | with self.assertRaisesRegex(RuntimeError, "Cannot find dim 'N'"): | 
|  | create('N:2,C:3').align_to('C') | 
|  |  | 
|  | # names not found | 
|  | with self.assertRaisesRegex(RuntimeError, "Cannot find dim 'C'"): | 
|  | create('N:2,C:3').align_to('D', 'N') | 
|  |  | 
|  | def test_align_to_ellipsis(self): | 
|  | tensor = create('N:7,H:3,W:5,C:2') | 
|  |  | 
|  | # ... = ['N', 'H', 'W', 'C'] | 
|  | output = tensor.align_to('...') | 
|  | self.assertEqual(output.names, ['N', 'H', 'W', 'C']) | 
|  | self.assertEqual(output.shape, [7, 3, 5, 2]) | 
|  |  | 
|  | # ... = ['H', 'C'] | 
|  | output = tensor.align_to('...', 'W', 'N') | 
|  | self.assertEqual(output.names, ['H', 'C', 'W', 'N']) | 
|  | self.assertEqual(output.shape, [3, 2, 5, 7]) | 
|  |  | 
|  | # ... = ['N', 'W'] | 
|  | output = tensor.align_to('H', 'C', '...') | 
|  | self.assertEqual(output.names, ['H', 'C', 'N', 'W']) | 
|  | self.assertEqual(output.shape, [3, 2, 7, 5]) | 
|  |  | 
|  | # ... = ['H', 'C'] | 
|  | output = tensor.align_to('W', '...', 'N') | 
|  | self.assertEqual(output.names, ['W', 'H', 'C', 'N']) | 
|  | self.assertEqual(output.shape, [5, 3, 2, 7]) | 
|  |  | 
|  | # ... = [] | 
|  | output = tensor.align_to('N', '...', 'C', 'D', 'H', 'W') | 
|  | self.assertEqual(output.names, ['N', 'C', 'D', 'H', 'W']) | 
|  | self.assertEqual(output.shape, [7, 2, 1, 3, 5]) | 
|  |  | 
|  | # Input tensor partially named | 
|  | partially_named = create('None:2,None:3,None:5,C:7') | 
|  | output = partially_named.align_to('C', '...') | 
|  | self.assertEqual(output.names, ['C', None, None, None]) | 
|  | self.assertEqual(output.shape, [7, 2, 3, 5]) | 
|  |  | 
|  | with self.assertRaisesRegex(RuntimeError, "order of dimensions cannot contain a None"): | 
|  | partially_named.align_to('C', None, '...') | 
|  |  | 
|  | # Input order partially named | 
|  | with self.assertRaisesRegex(RuntimeError, "cannot contain a None name"): | 
|  | tensor.align_to('...', 'N', None) | 
|  |  | 
|  | # Input order duplicate names | 
|  | with self.assertRaisesRegex(RuntimeError, "duplicate names"): | 
|  | tensor.align_to('...', 'N', 'N') | 
|  |  | 
|  | def test_align_as(self): | 
|  | # align_as calls align_to internally. align_to has pretty substantial tests, | 
|  | # so just test some basic things here. | 
|  | tensor = create('C:2,N:3,H:5') | 
|  | other = create('N:1,H:1,W:1,C:1') | 
|  | output = tensor.align_as(other) | 
|  | self.assertEqual(output.names, ['N', 'H', 'W', 'C']) | 
|  | self.assertEqual(output.shape, [3, 5, 1, 2]) | 
|  |  | 
|  | @unittest.skip("Not implemented yet") | 
|  | def test_align_tensors_two_inputs(self): | 
|  | def _test(tensor_namedshape, align_names, expected_sizes, expected_error): | 
|  | tensor_names, tensor_sizes = tensor_namedshape | 
|  | tensor = torch.empty(*tensor_sizes, names=tensor_names) | 
|  | other = torch.empty([1] * len(align_names), names=align_names) | 
|  | if expected_error is not None: | 
|  | with self.assertRaisesRegex(RuntimeError, expected_error): | 
|  | torch.align_tensors(tensor, other) | 
|  | return | 
|  |  | 
|  | output, _ = torch.align_tensors(tensor, other) | 
|  | self.assertEqual(output.shape, expected_sizes) | 
|  | self.assertEqual(output.names, align_names) | 
|  |  | 
|  | Case = namedtuple('Case', [ | 
|  | 'tensor_namedshape', | 
|  | 'align_names', | 
|  | 'expected_sizes', | 
|  | 'expected_error', | 
|  | ]) | 
|  |  | 
|  | tests = [ | 
|  | # basic tests | 
|  | Case(tensor_namedshape=(['C'], [2]), | 
|  | align_names=['C'], | 
|  | expected_sizes=[2], | 
|  | expected_error=None), | 
|  | Case(tensor_namedshape=(['C'], [2]), | 
|  | align_names=['D'], | 
|  | expected_sizes=None, | 
|  | expected_error='not a subsequence'), | 
|  |  | 
|  | # single-dim alignment test | 
|  | Case(tensor_namedshape=(['C'], [2]), | 
|  | align_names=['N', 'C'], | 
|  | expected_sizes=[1, 2], | 
|  | expected_error=None), | 
|  | Case(tensor_namedshape=[['N'], [2]], | 
|  | align_names=['N', 'C'], | 
|  | expected_sizes=[2, 1], | 
|  | expected_error=None), | 
|  |  | 
|  | # multiple dim alignment test | 
|  | Case(tensor_namedshape=[['N', 'C'], [2, 3]], | 
|  | align_names=['N', 'H', 'C', 'W'], | 
|  | expected_sizes=[2, 1, 3, 1], | 
|  | expected_error=None), | 
|  | Case(tensor_namedshape=[['N', 'C'], [2, 3]], | 
|  | align_names=['C', 'H', 'N', 'W'], | 
|  | expected_sizes=None, | 
|  | expected_error='not a subsequence'), | 
|  |  | 
|  | # scalar tensor tests | 
|  | Case(tensor_namedshape=[None, [[]]], | 
|  | align_names=['N', 'C'], | 
|  | expected_sizes=[1, 1], | 
|  | expected_error=None), | 
|  | Case(tensor_namedshape=[[], [[]]], | 
|  | align_names=[None, None], | 
|  | expected_sizes=[1, 1], | 
|  | expected_error=None), | 
|  |  | 
|  | # unnamed tensor tests | 
|  | Case(tensor_namedshape=[None, [2, 3]], | 
|  | align_names=[None, None], | 
|  | expected_sizes=[2, 3], | 
|  | expected_error=None), | 
|  | Case(tensor_namedshape=[None, [2, 3]], | 
|  | align_names=[None, None, None], | 
|  | expected_sizes=[1, 2, 3], | 
|  | expected_error=None), | 
|  | Case(tensor_namedshape=[None, [2]], | 
|  | align_names=['N'], | 
|  | expected_sizes=None, | 
|  | expected_error='not a subsequence'), | 
|  |  | 
|  | # unnamed dim alignment tests | 
|  | Case(tensor_namedshape=[[None], [2]], | 
|  | align_names=['N', None], | 
|  | expected_sizes=[1, 2], | 
|  | expected_error=None), | 
|  | Case(tensor_namedshape=[[None], [2]], | 
|  | align_names=['N', None, None, None], | 
|  | expected_sizes=[1, 1, 1, 2], | 
|  | expected_error=None), | 
|  | Case(tensor_namedshape=[['N'], [2]], | 
|  | align_names=['N', None, None, None], | 
|  | expected_sizes=[2, 1, 1, 1], | 
|  | expected_error=None), | 
|  | Case(tensor_namedshape=[[None, 'N', None], [2, 3, 5]], | 
|  | align_names=[None, None, 'N', None], | 
|  | expected_sizes=[1, 2, 3, 5], | 
|  | expected_error=None), | 
|  | Case(tensor_namedshape=[[None], [2]], | 
|  | align_names=[None, 'N'], | 
|  | expected_sizes=None, | 
|  | expected_error='absolute position from the right'), | 
|  | Case(tensor_namedshape=[None, [2]], | 
|  | align_names=[None, 'N'], | 
|  | expected_sizes=None, | 
|  | expected_error='absolute position from the right'), | 
|  | Case(tensor_namedshape=[[None, 'N'], [2, 3]], | 
|  | align_names=[None, 'C', 'N'], | 
|  | expected_sizes=None, | 
|  | expected_error='absolute position from the right'), | 
|  | ] | 
|  |  | 
|  | for test in tests: | 
|  | _test(*test) | 
|  |  | 
|  | @unittest.skip("Not implemented yet") | 
|  | def test_align_tensors(self): | 
|  | def reference_fn(*tensors): | 
|  | longest_names = tensors[0].names | 
|  | for tensor in tensors: | 
|  | if len(tensor.names) > len(longest_names): | 
|  | longest_names = tensor.names | 
|  | return [tensor.align_to(*longest_names) for tensor in tensors] | 
|  |  | 
|  | x = torch.empty(1, 1, names=('N', 'H')) | 
|  | y = torch.empty(2, 3, 5, names=('N', 'C', 'H')) | 
|  | z = torch.empty(2, names=('N',)) | 
|  | output = torch.align_tensors(x, y, z) | 
|  | expected_tensors = reference_fn(x, y, z) | 
|  | for tensor, expected in zip(output, expected_tensors): | 
|  | self.assertTensorDataAndNamesEqual(tensor, expected) | 
|  |  | 
def test_mm(self):
    # Name inference for torch.mm: output names are (left row name, right
    # column name); the contracted pair is dropped.
    for device in torch.testing.get_all_device_types():
        # (callable, namedshape spec per positional arg, extra kwargs).
        # Specs are turned into tensors right before each check.
        checks = [
            (torch.mm, ('N:3,C:2', 'W:2,H:5'),
             dict(expected_names=('N', 'H'))),
            # left arg is unnamed
            (torch.mm, ('3,2', 'W:2,H:5'),
             dict(expected_names=(None, 'H'))),
            # right arg is unnamed
            (torch.mm, ('N:3,C:2', '2,5'),
             dict(expected_names=('N', None))),
            # out=; the leading '0' spec is the output buffer
            (out_fn(torch.mm), ('0', 'N:3,C:2', 'W:2,H:5'),
             dict(expected_names=('N', 'H'))),
            # result would be named (N, N)
            (torch.mm, ('N:3,C:2', 'W:2,N:5'),
             dict(maybe_raises_regex='with duplicate names')),
        ]
        for fn, specs, expectation in checks:
            self._test_name_inference(
                fn, device=device,
                args=tuple(create(spec) for spec in specs),
                **expectation)
|  |  | 
def test_expand(self):
    # Name inference for Tensor.expand: existing dims keep their names and
    # any new leading dims introduced by the expansion are unnamed.
    for device in torch.testing.get_all_device_types():
        # Expanding a size-1 dim keeps its name. The trailing comma matters:
        # ('D') is just the string 'D', not a one-element tuple of names.
        self._test_name_inference(
            Tensor.expand, device=device,
            args=(create('D:1'), [3]), expected_names=('D',))

        # New leading dims are unnamed; the original trailing names survive.
        self._test_name_inference(
            Tensor.expand, device=device,
            args=(create('H:3,W:2'), [10, 3, 3, 2]),
            expected_names=(None, None, 'H', 'W'))

        # A fully unnamed input stays fully unnamed.
        self._test_name_inference(
            Tensor.expand, device=device,
            args=(create('3, 2'), [10, 3, 3, 2]),
            expected_names=(None, None, None, None))
|  |  | 
def test_addmm(self):
    # Name inference for addmm (bias + mat1 @ mat2) across its functional,
    # out= and in-place forms.
    for device in torch.testing.get_all_device_types():
        # (callable, namedshape spec per positional arg, extra kwargs)
        checks = [
            # full names
            (torch.addmm, ('N:3,H:5', 'N:3,C:2', 'W:2,H:5'),
             dict(expected_names=('N', 'H'))),
            # no name on bias
            (torch.addmm, ('3,5', 'N:3,C:2', 'W:2,H:5'),
             dict(expected_names=('N', 'H'))),
            # partially named bias
            (torch.addmm, ('N:3,None:5', 'N:3,C:2', 'W:2,H:5'),
             dict(expected_names=('N', 'H'))),
            # out=; the leading '0' spec is the output buffer
            (out_fn(torch.addmm), ('0', 'N:3,None:5', 'N:3,C:2', 'W:2,H:5'),
             dict(expected_names=('N', 'H'))),
            # inplace
            (torch.Tensor.addmm_, ('N:3,H:5', 'N:3,C:2', 'W:2,H:5'),
             dict(expected_names=('N', 'H'))),
            # the mm part would be named (N, N)
            (torch.addmm, ('N:3,H:5', 'N:3,C:2', 'W:2,N:5'),
             dict(maybe_raises_regex='with duplicate names')),
        ]
        for fn, specs, expectation in checks:
            self._test_name_inference(
                fn, device=device,
                args=tuple(create(spec) for spec in specs),
                **expectation)
|  |  | 
def test_bmm(self):
    # Name inference for torch.bmm: batch names are matched across inputs,
    # then the per-matrix names follow the mm rule.
    for device in torch.testing.get_all_device_types():
        # (callable, namedshape spec per positional arg, extra kwargs)
        checks = [
            # full names
            (torch.bmm, ('N:7,A:3,B:2', 'N:7,A:2,B:5'),
             dict(expected_names=('N', 'A', 'B'))),
            # no name on left tensor
            (torch.bmm, ('7,3,2', 'N:7,A:2,B:5'),
             dict(expected_names=('N', None, 'B'))),
            # no name on right tensor
            (torch.bmm, ('N:7,A:3,B:2', '7,2,5'),
             dict(expected_names=('N', 'A', None))),
            # out=; the leading '0' spec is the output buffer
            (out_fn(torch.bmm), ('0', 'N:7,A:3,B:2', 'N:7,A:2,B:5'),
             dict(expected_names=('N', 'A', 'B'))),
            # duplicate names after mm: result would be (N, A, A)
            (torch.bmm, ('N:7,A:3,B:2', 'N:7,B:2,A:5'),
             dict(maybe_raises_regex='with duplicate names')),
            # batch names N and M cannot be matched
            (torch.bmm, ('N:3,A:3,B:3', 'M:3,A:3,B:3'),
             dict(maybe_raises_regex='do not match')),
            # misalignment: the right operand's N sits in a matrix dim, so the
            # batch dimension would get contracted
            (torch.bmm, ('N:3,A:3,B:3', 'None:3,N:3,B:3'),
             dict(maybe_raises_regex='misaligned')),
        ]
        for fn, specs, expectation in checks:
            self._test_name_inference(
                fn, device=device,
                args=tuple(create(spec) for spec in specs),
                **expectation)
|  |  | 
def test_matmul(self):
    # Name inference for torch.matmul across operand ranks: 0-D rejection,
    # 1-D/2-D/N-D combinations, the out= form, and the error paths for
    # duplicate output names and non-matching batch dims.
    for device in torch.testing.get_all_device_types():
        # 0-D operands are rejected on either side
        self._test_name_inference(
            torch.matmul, device=device,
            args=(create(''), create('A:2')),
            maybe_raises_regex='at least 1D')
        self._test_name_inference(
            torch.matmul, device=device,
            args=(create('A:2'), create('')),
            maybe_raises_regex='at least 1D')

        # 1D @ 1D: both dims are contracted away
        self._test_name_inference(
            torch.matmul, device=device,
            args=(create('A:2'), create('B:2')),
            expected_names=[])

        # ND @ 1D
        self._test_name_inference(
            torch.matmul, device=device,
            args=(create('A:3,C:2'), create('B:2')),
            expected_names=['A'])
        self._test_name_inference(
            torch.matmul, device=device,
            args=(create('A:5,C:3,D:2'), create('B:2')),
            expected_names=['A', 'C'])

        # 1D @ ND
        self._test_name_inference(
            torch.matmul, device=device,
            args=(create('C:2'), create('A:2,B:3')),
            expected_names=['B'])
        self._test_name_inference(
            torch.matmul, device=device,
            args=(create('C:2'), create('A:3,B:2,D:5')),
            expected_names=['A', 'D'])

        # 2D @ 2D
        self._test_name_inference(
            torch.matmul, device=device,
            args=(create('A:3,B:2'), create('A:2,B:3')),
            expected_names=['A', 'B'])
        self._test_name_inference(
            torch.matmul, device=device,
            args=(create('A:3,B:2'), create('B:2,A:5')),
            maybe_raises_regex='with duplicate names')

        # ND @ ND where N >= 2 (batch dims are matched from the right)
        self._test_name_inference(
            torch.matmul, device=device,
            args=(create('C:5,A:3,B:2'), create('A:2,B:3')),
            expected_names=['C', 'A', 'B'])
        self._test_name_inference(
            torch.matmul, device=device,
            args=(create('C:5,A:3,B:2'), create('None:1,A:2,B:3')),
            expected_names=['C', 'A', 'B'])
        self._test_name_inference(
            torch.matmul, device=device,
            args=(create('C:5,A:3,B:2'), create('None:2,None:1,A:2,B:3')),
            expected_names=[None, 'C', 'A', 'B'])

        # out=
        self._test_name_inference(
            out_fn(torch.matmul), device=device,
            args=(create('0'), create('N:7,A:3,B:2'), create('N:7,A:2,B:5')),
            expected_names=('N', 'A', 'B'))

        # duplicate names after the batched mm: result would be (N, A, A).
        # This must exercise matmul itself; test_bmm already covers bmm.
        self._test_name_inference(
            torch.matmul, device=device,
            args=(create('N:7,A:3,B:2'), create('N:7,B:2,A:5')),
            maybe_raises_regex='with duplicate names')

        # batch dims do not match: N on the left vs. A on the right (the
        # right operand's N sits in a matrix dim and would be contracted)
        self._test_name_inference(
            torch.matmul, device=device,
            args=(create('N:3,A:3,B:3'), create('A:3,N:3,B:3')),
            maybe_raises_regex='do not match')
|  |  | 
def test_mv(self):
    # Name inference for torch.mv: only the matrix's row name survives,
    # since the vector's single dim is contracted away.
    for device in torch.testing.get_all_device_types():
        # (callable, namedshape spec per positional arg, expected names)
        checks = [
            (torch.mv, ('N:3,C:2', 'W:2'), ('N',)),
            # left arg is unnamed
            (torch.mv, ('3,2', 'W:2'), (None,)),
            # right arg is unnamed
            (torch.mv, ('N:3,C:2', '2'), ('N',)),
            # out=; the leading '0' spec is the output buffer
            (out_fn(torch.mv), ('0', 'N:3,C:2', 'W:2'), ('N',)),
        ]
        for fn, specs, names in checks:
            self._test_name_inference(
                fn, device=device,
                args=tuple(create(spec) for spec in specs),
                expected_names=names)
|  |  | 
def test_addmv(self):
    # Name inference for addmv (bias + mat @ vec) in its functional, out=
    # and in-place forms.
    for device in torch.testing.get_all_device_types():
        # (callable, namedshape spec per positional arg, expected names)
        checks = [
            # full names
            (torch.addmv, ('N:3', 'N:3,C:2', 'H:2'), ['N']),
            # no name on bias
            (torch.addmv, ('3', 'N:3,C:2', 'H:2'), ('N',)),
            # out=; the leading '0' spec is the output buffer
            (out_fn(torch.addmv), ('0', 'N:3', 'N:3,C:2', 'H:2'), ('N',)),
            # inplace
            (torch.Tensor.addmv_, ('N:3', 'N:3,C:2', 'H:2'), ('N',)),
        ]
        for fn, specs, names in checks:
            self._test_name_inference(
                fn, device=device,
                args=tuple(create(spec) for spec in specs),
                expected_names=names)
|  |  | 
|  | def test_autograd_ignores_names(self): | 
|  | # sigmoid forward is supported by named tensors, but sigmoid_backward | 
|  | # is not (see native_functions.yaml). Test that autograd ignores names | 
|  | # and that the sigmoid_backward succeeds. | 
|  | x = torch.randn(3, 3, names=('N', 'C'), requires_grad=True) | 
|  | x.sigmoid().sum().backward() | 
|  |  | 
|  | def test_tensor_grad_is_unnamed(self): | 
|  | x = torch.randn(3, 3, names=(None, None), requires_grad=True) | 
|  | y = torch.randn(3, 3, names=('N', 'C'), requires_grad=True) | 
|  | (x * y).sum().backward() | 
|  |  | 
|  | # Check that names weren't propagated | 
|  | self.assertEqual(y.grad.names, [None, None]) | 
|  | self.assertEqual(x.grad.names, [None, None]) | 
|  |  | 
|  | def test_autograd_warns_named_grad(self): | 
|  | base = torch.randn(3, 3, names=('N', 'C')) | 
|  | named_grad = base.clone() | 
|  | base.requires_grad_() | 
|  |  | 
|  | with warnings.catch_warnings(record=True) as warns: | 
|  | # Cause all warnings to always be triggered. | 
|  | warnings.simplefilter("always") | 
|  | base.clone().backward(named_grad) | 
|  | self.assertEqual(len(warns), 1) | 
|  | self.assertTrue( | 
|  | str(warns[0].message).startswith('Autograd was passed a named grad tensor')) | 
|  |  | 
|  | def test_nyi_dimname_overload_msg(self): | 
|  | x = torch.randn(3, 3) | 
|  | with self.assertRaisesRegex(RuntimeError, "squeeze: You passed a dimname"): | 
|  | x.squeeze_("N") | 
|  |  | 
def test_dot(self):
    # torch.dot ignores the names of both tensors: the result is expected to
    # carry no names at all.
    lhs_spec, rhs_spec = 'C:2', 'W:2'
    for dev in torch.testing.get_all_device_types():
        self._test_name_inference(
            torch.dot,
            device=dev,
            args=(create(lhs_spec), create(rhs_spec)),
            expected_names=[])
|  |  | 
|  | def test_comparison_ops(self): | 
|  | for device in torch.testing.get_all_device_types(): | 
|  | a = torch.randn(3, 3, names=('N', 'C'), device=device) | 
|  | b = torch.randn(3, 3, names=('N', 'C'), device=device) | 
|  | scalar = torch.randn([], device=device) | 
|  |  | 
|  | self.assertEqual((a == b).names, ['N', 'C']) | 
|  | self.assertEqual((a != b).names, ['N', 'C']) | 
|  | self.assertEqual((a > b).names, ['N', 'C']) | 
|  | self.assertEqual((a < b).names, ['N', 'C']) | 
|  | self.assertEqual((a >= b).names, ['N', 'C']) | 
|  | self.assertEqual((a <= b).names, ['N', 'C']) | 
|  |  | 
|  | self.assertEqual((a == 1).names, ['N', 'C']) | 
|  | self.assertEqual((a != 1).names, ['N', 'C']) | 
|  | self.assertEqual((a > 1).names, ['N', 'C']) | 
|  | self.assertEqual((a < 1).names, ['N', 'C']) | 
|  | self.assertEqual((a >= 1).names, ['N', 'C']) | 
|  | self.assertEqual((a <= 1).names, ['N', 'C']) | 
|  |  | 
|  | self.assertEqual((a == scalar).names, ['N', 'C']) | 
|  | self.assertEqual((a != scalar).names, ['N', 'C']) | 
|  | self.assertEqual((a > scalar).names, ['N', 'C']) | 
|  | self.assertEqual((a < scalar).names, ['N', 'C']) | 
|  | self.assertEqual((a >= scalar).names, ['N', 'C']) | 
|  | self.assertEqual((a <= scalar).names, ['N', 'C']) | 
|  |  | 
|  | res = torch.empty(3, 3, dtype=torch.bool, device=device) | 
|  | torch.eq(a, b, out=res) | 
|  | self.assertEqual(res.names, ['N', 'C']) | 
|  | torch.ne(a, b, out=res) | 
|  | self.assertEqual(res.names, ['N', 'C']) | 
|  | torch.lt(a, b, out=res) | 
|  | self.assertEqual(res.names, ['N', 'C']) | 
|  | torch.gt(a, b, out=res) | 
|  | self.assertEqual(res.names, ['N', 'C']) | 
|  | torch.le(a, b, out=res) | 
|  | self.assertEqual(res.names, ['N', 'C']) | 
|  | torch.ge(a, b, out=res) | 
|  | self.assertEqual(res.names, ['N', 'C']) | 
|  |  | 
|  | res = torch.isnan(a) | 
|  | self.assertEqual(res.names, ['N', 'C']) | 
|  |  | 
|  | res = torch.isinf(a) | 
|  | self.assertEqual(res.names, ['N', 'C']) | 
|  |  | 
|  |  | 
if __name__ == '__main__':
    # Script entry point: hand off to run_tests, imported at the top of this
    # module from torch.testing._internal.common_utils.
    run_tests()