|  | # Owner(s): ["oncall: jit"] | 
|  |  | 
|  | import os | 
|  | import sys | 
|  | import torch | 
|  | from torch.utils._pytree import tree_map | 
|  | import unittest | 
|  |  | 
|  | from torch.testing._internal.common_utils import run_tests | 
|  | from torch.fx.operator_schemas import normalize_function | 
|  | from torch._subclasses.schema_check_mode import SchemaCheckMode | 
|  | from torch.utils._python_dispatch import TorchDispatchMode | 
|  | from torch.testing._internal.common_methods_invocations import op_db | 
|  | from torch.testing._internal.jit_utils import JitTestCase | 
|  | from torch.testing._internal.common_device_type import ops, OpDTypes, instantiate_device_type_tests | 
# Put the grandparent directory of this file (presumably PyTorch's `test/`
# directory) on sys.path so shared test helpers living there are importable.
# NOTE(review): nothing in this file appears to import from it directly;
# confirm it is still needed before removing.
pytorch_test_dir = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
sys.path.append(pytorch_test_dir)
|  |  | 
def secretly_aliasing(x):
    """Return a flattened view of ``x`` that shares its storage.

    Deliberately mismatched kernel: it is registered against the schema
    ``secretly_aliasing(Tensor x) -> Tensor``, which does not declare that
    the output aliases ``x``.
    """
    flat_view = x.view(-1)
    return flat_view
|  |  | 
def secretly_mutating(x):
    """Double ``x`` in place, then return a fresh tensor equal to ``x * 3``.

    Deliberately mismatched kernel: its schema
    ``secretly_mutating(Tensor x) -> Tensor`` does not mark ``x`` as
    mutable, yet the in-place ``mul_`` below modifies it.
    """
    x.mul_(2)  # hidden, undeclared mutation of the input
    tripled = x * 3
    return tripled
|  |  | 
def output_is_input(x):
    """Return the very same tensor object that was passed in.

    Its schema ``output_is_input(Tensor(a) x) -> Tensor(a)`` does declare
    the alias, but the kernel hands back the input object itself rather
    than a view. SchemaCheckMode is expected to reject returning an input
    directly.
    """
    same_object = x
    return same_object
|  |  | 
# Define a "bad_schemas" namespace whose op schemas deliberately do not match
# what their Python kernels actually do. The Library objects are kept in
# module-level names (custom_lib, custom_lib_cpu, custom_lib_meta).
custom_lib = torch.library.Library("bad_schemas", "DEF")
for _schema in (
    "secretly_aliasing(Tensor x) -> Tensor",
    "secretly_mutating(Tensor x) -> Tensor",
    "output_is_input(Tensor(a) x) -> Tensor(a)",
):
    custom_lib.define(_schema)

# The same Python kernels back both the CPU and Meta dispatch keys. Each
# kernel's __name__ is identical to the op name it implements.
custom_lib_cpu = torch.library.Library("bad_schemas", "IMPL", "CPU")
custom_lib_meta = torch.library.Library("bad_schemas", "IMPL", "Meta")
for _lib in (custom_lib_cpu, custom_lib_meta):
    for _kernel in (secretly_aliasing, secretly_mutating, output_is_input):
        _lib.impl(_kernel.__name__, _kernel)
# Don't leak loop variables into the module namespace.
del _schema, _lib, _kernel
|  |  | 
# This __torch_dispatch__ tensor subclass simulates operators with incorrect
# schemas. It is used to test that SchemaCheckMode detects them.
class IncorrectAliasTensor(torch.Tensor):
    """Wrapper tensor that secretly violates the schemas of selected ops.

    * ``ALIAS_ARG_OUT``: after the op runs, the first argument is rebound to
      the op's output, so input and output alias.
    * ``MUTATE_ARGS_OUT``: after the op runs, the first argument's data is
      replaced with fresh random data, i.e. a hidden mutation.
    * ``ALIAS_OUT_OUT``: output 0 is replaced by output 1, so the two
      outputs alias each other.

    Entries are fully qualified schema names such as ``"aten::add"``.
    """

    ALIAS_ARG_OUT = {"aten::add"}
    ALIAS_OUT_OUT = {"aten::aminmax"}
    MUTATE_ARGS_OUT = {"aten::sub"}

    # The real tensor being wrapped.
    elem: torch.Tensor

    __slots__ = ['elem']

    __torch_function__ = torch._C._disabled_torch_function_impl

    @staticmethod
    def __new__(cls, elem, *args, **kwargs):
        """Create a storage-less wrapper advertising ``elem``'s metadata.

        Only the ``requires_grad`` keyword (default False) is honoured.
        Other extra arguments are ignored.
        """
        # The wrapping tensor (IncorrectAliasTensor) shouldn't hold any
        # memory for the class in question, but it should still
        # advertise the same device as before
        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
            cls, elem.size(),
            strides=elem.stride(), storage_offset=elem.storage_offset(),
            # TODO: clone storage aliasing
            dtype=elem.dtype, layout=elem.layout,
            device=elem.device, requires_grad=kwargs.get("requires_grad", False)
        )
        # ...the real tensor is held as an element on the tensor.
        r.elem = elem.detach() if r.requires_grad else elem
        return r

    def __repr__(self):
        return super().__repr__(tensor_contents=f"{self.elem}")

    @classmethod
    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
        """Run ``func`` on the unwrapped tensors, injecting schema violations.

        The result is re-wrapped in ``cls``.
        """
        # kwargs defaults to None; normalize it so the `**` unpacking below
        # does not raise TypeError.
        kwargs = kwargs or {}

        def unwrap(e):
            return e.elem if isinstance(e, cls) else e

        def wrap(e):
            return cls(e) if isinstance(e, torch.Tensor) else e

        out = func(*tree_map(unwrap, args), **tree_map(unwrap, kwargs))
        op_name = func._schema.name
        if op_name in cls.ALIAS_ARG_OUT:
            # Secretly make the first input alias the output.
            args[0].elem = out
        if op_name in cls.MUTATE_ARGS_OUT:
            # Secretly "mutate" the first input by swapping in new data.
            args[0].elem = torch.rand(args[0].elem.shape)
        if op_name in cls.ALIAS_OUT_OUT:
            # Secretly make output 0 the same tensor as output 1.
            incorrect_out = list(out)
            incorrect_out[0] = incorrect_out[1]
            return tree_map(wrap, tuple(incorrect_out))

        return tree_map(wrap, out)
|  |  | 
# Tests various schema checking functionalities.
class TestSchemaCheck(JitTestCase):
    """Tests for SchemaCheckMode and the aliasing/SchemaInfo bindings.

    Covers SchemaCheckMode itself plus the ``torch._C._is_alias_of``,
    ``torch._C._overlaps`` and ``torch._C._SchemaInfo`` bindings.
    """

    # Tests that SchemaCheckMode records operator order with grad
    def test_schema_check_mode_operator_order(self):
        with SchemaCheckMode() as schema_check:
            x = torch.rand((3, 3), requires_grad=True)
            x.relu().sin()
        self.assertEqual(["aten::rand", "aten::relu", "aten::detach", "aten::sin"], schema_check.ops)

    # Tests that SchemaCheckMode records operator order without grad
    def test_schema_check_mode_operator_order_without_grad(self):
        with SchemaCheckMode() as schema_check:
            x = torch.rand((3, 3), requires_grad=False)
            x.relu().sin()
        self.assertEqual(["aten::rand", "aten::relu", "aten::sin"], schema_check.ops)

    # Tests that SchemaCheckMode records mutations and aliases with none expected
    def test_schema_check_mode_mutated_aliasing_none(self):
        # NB: previously requires_grad=True, but this induces a detach for
        # saved variable
        x = torch.rand((3, 3))
        with SchemaCheckMode() as schema_check:
            x.relu().sin()
        self.assertEqual([], schema_check.mutated)
        self.assertEqual([], schema_check.aliasing)

    # Tests that SchemaCheckMode records mutations and aliases with mutation expected
    def test_schema_check_mode_mutated_aliasing_mutation(self):
        actual = torch.rand((3, 3), requires_grad=False)
        with SchemaCheckMode() as schema_check:
            actual.sinh_()
        self.assertEqual([('aten::sinh_', 'input')], schema_check.mutated)
        self.assertEqual([('aten::sinh_', 'input', 'output_0')], schema_check.aliasing)

    # Tests that SchemaCheckMode records mutations and aliases with resize_
    def test_schema_check_mode_mutated_aliasing_resize_(self):
        actual = torch.rand((3, 3), requires_grad=False)
        with SchemaCheckMode() as schema_check:
            actual.resize_(9)
        self.assertEqual([('aten::resize_', 'input')], schema_check.mutated)
        self.assertEqual([('aten::resize_', 'input', 'output_0')], schema_check.aliasing)

    # Tests that SchemaCheckMode records mutations and aliases with aliasing inputs
    def test_schema_check_mode_mutated_aliasing_aliasing_inputs(self):
        actual = torch.rand((3, 3))
        y = actual
        with SchemaCheckMode() as schema_check:
            actual.add_(y)
        self.assertEqual(
            [
                ('aten::add_', 'input'),
                ('aten::add_', 'other')
            ],
            schema_check.mutated
        )
        self.assertEqual(
            [
                ('aten::add_', 'input', 'output_0'),
                ('aten::add_', 'other', 'output_0')
            ],
            schema_check.aliasing
        )

    # Tests that SchemaCheckMode records mutations and alias with as_strided
    def test_schema_check_mode_mutated_aliasing_as_strided(self):
        x = torch.rand((3, 6, 4))
        with SchemaCheckMode() as schema_check:
            x.as_strided_([3, 6, 4], [9, 1, 1])
        self.assertEqual(
            [
                ('aten::as_strided_', 'input')
            ],
            schema_check.mutated
        )
        self.assertEqual(
            [
                ('aten::as_strided_', 'input', 'output_0')
            ],
            schema_check.aliasing
        )

    # Tests that SchemaCheckMode records mutations and aliases with multiple outputs
    def test_schema_check_mode_mutated_aliasing_multiple_outputs(self):
        x = torch.arange(9.)
        m_actual = torch.arange(9.)
        e_actual = torch.zeros([9], dtype=torch.int32)
        with SchemaCheckMode() as schema_check:
            torch.frexp(x, out=(m_actual, e_actual))
        self.assertEqual(
            [
                ('aten::frexp', 'mantissa'),
                ('aten::frexp', 'exponent')
            ],
            schema_check.mutated
        )
        self.assertEqual(
            [
                ('aten::frexp', 'mantissa', 'output_0'),
                ('aten::frexp', 'exponent', 'output_1')
            ],
            schema_check.aliasing
        )

    # Tests that SchemaCheckMode records mutations and aliases with aliasing outputs
    def test_schema_check_mode_mutated_aliasing_aliasing_outputs(self):
        x = torch.rand((3, 3))
        actual = torch.zeros(3)
        with SchemaCheckMode() as schema_check:
            torch.aminmax(x, dim=0, out=[actual, actual])
        self.assertEqual(
            [
                ('aten::aminmax', 'min'),
                ('aten::aminmax', 'max')
            ],
            schema_check.mutated
        )
        self.assertEqual(
            [
                ('aten::aminmax', 'min', 'output_0'),
                ('aten::aminmax', 'min', 'output_1'),
                ('aten::aminmax', 'max', 'output_0'),
                ('aten::aminmax', 'max', 'output_1')
            ],
            schema_check.aliasing
        )

    # Tests that SchemaCheckMode wraps torch.Tensor
    def test_schema_check_mode_functionality(self):
        x = torch.rand((3, 3), requires_grad=True)
        expected = x.relu().sin()
        with SchemaCheckMode():
            actual = x.relu().sin()
        self.assertEqual(expected, actual)

    # Tests that SchemaCheckMode wraps torch.Tensor when an argument's default is overridden
    def test_schema_check_mode_functionality_default_replaced(self):
        x = torch.rand((3, 3), requires_grad=True)
        expected = x.add(x, alpha=2)
        with SchemaCheckMode():
            actual = x.add(x, alpha=2)
        self.assertEqual(expected, actual)

    # Tests that SchemaCheckMode wraps torch.Tensor when there is a Tensor[] argument
    def test_schema_check_mode_functionality_list_input(self):
        a = torch.rand((3, 3))
        b = torch.rand((3, 3))
        c = torch.rand((3, 3))
        expected = torch.linalg.multi_dot([a, b, c])
        with SchemaCheckMode():
            actual = torch.linalg.multi_dot([a, b, c])
        self.assertEqual(expected, actual)

    # Tests that SchemaCheckMode wraps torch.Tensor with an op that has the (a -> *) notation
    def test_schema_check_mode_functionality_wildcard_after(self):
        x = torch.rand((3, 3))
        expected = x.chunk(6)
        with SchemaCheckMode():
            actual = x.chunk(6)
        self.assertEqual(expected, actual)

    # Tests that SchemaCheckMode wraps torch.Tensor when there is a kwarg tensor input
    @unittest.skipIf(not torch._C.has_spectral, "ATen not built with FFT.")
    def test_schema_check_mode_functionality_kwarg_tensor(self):
        x = torch.rand((3, 5))
        w = torch.rand((4))
        expected = torch.stft(x, 4, win_length=4, window=w, return_complex=True)
        with SchemaCheckMode():
            actual = torch.stft(x, 4, win_length=4, window=w, return_complex=True)
        self.assertEqual(expected, actual)

    # Tests that SchemaCheckMode wraps torch.Tensor with a mutable op
    def test_schema_check_mode_functionality_mutable_inputs(self):
        expected = torch.rand((3, 3), requires_grad=False)
        actual = torch.clone(expected)
        expected.sinh_()
        with SchemaCheckMode():
            actual.sinh_()
        self.assertEqual(expected, actual)

    # Tests that SchemaCheckMode wraps torch.Tensor when inputs alias
    def test_schema_check_mode_functionality_aliasing_inputs(self):
        expected = torch.rand((3, 3))
        x = expected
        actual = torch.clone(expected)
        y = actual
        expected.add_(x)
        with SchemaCheckMode():
            actual.add_(y)
        self.assertEqual(expected, actual)

    # Tests that SchemaCheckMode wraps torch.Tensor with multiple tensor outputs
    def test_schema_check_mode_functionality_with_multiple_outputs(self):
        x = torch.arange(9.)
        m_expected, e_expected = torch.frexp(x)
        m_actual = torch.arange(9.)
        e_actual = torch.zeros([9], dtype=torch.int32)
        with SchemaCheckMode():
            torch.frexp(x, out=(m_actual, e_actual))
        self.assertEqual(m_expected, m_actual)
        self.assertEqual(e_expected, e_actual)

    # Tests that SchemaCheckMode wraps torch.Tensor with aliasing outputs due to aliasing inputs
    def test_schema_check_mode_functionality_with_multiple_outputs_aliasing(self):
        x = torch.rand((3, 3))
        actual = torch.zeros(3)
        with SchemaCheckMode():
            torch.aminmax(x, dim=0, out=[actual, actual])
        self.assertEqual(torch.amax(x, dim=0), actual)

    # Tests that SchemaCheckMode wraps torch.Tensor in ops with real Device input
    def test_schema_check_mode_functionality_device_input(self):
        with SchemaCheckMode():
            x = torch.rand((3, 3), device="cpu", dtype=torch.double)
            y = x + x
        self.assertEqual(x + x, y)

    # Tests that SchemaCheckMode wraps torch.Tensor in special training op edge case
    def test_schema_check_mode_functionality_training_op(self):
        x = torch.rand((3, 3), requires_grad=True)
        batch = torch.nn.BatchNorm1d(3, track_running_stats=True)
        expected = batch(x)
        with SchemaCheckMode():
            actual = batch(x)
        self.assertEqual(expected, actual)

    # Tests that SchemaCheckMode wraps torch.Tensor with nested training op edge case
    def test_schema_check_mode_functionality_nested_training_op(self):
        actual = torch.rand((3, 3))
        batch = torch.nn.BatchNorm1d(3, track_running_stats=True)
        expected = torch.clone(actual)
        expected.sinh_()
        expected.tanh_()
        expected.relu_()
        expected = batch(expected)

        with SchemaCheckMode():
            actual.sinh_()
            actual.tanh_()
            actual.relu_()
            actual = batch(actual)
        self.assertEqual(expected, actual)

    # Tests that SchemaCheckMode wraps torch.Tensor with empty list input
    def test_schema_check_mode_empty_list_input(self):
        expected = torch.atleast_1d([])
        with SchemaCheckMode():
            actual = torch.atleast_1d([])
        self.assertEqual(expected, actual)

    # Tests that an exception is raised for a mismatching mutation
    def test_mutation_check_fail(self):
        x = torch.rand((3, 3))
        y = torch.rand((3, 3))
        with self.assertRaisesRegex(RuntimeError, "Argument input is not defined as mutable but was mutated"):
            with SchemaCheckMode():
                IncorrectAliasTensor(x).sub(IncorrectAliasTensor(y))

    # Tests that an exception is raised for a mismatching mutation over multiple ops
    def test_mutation_check_fail_multiple_operators(self):
        x = torch.rand((3, 3))
        y = torch.rand((3, 3))
        with self.assertRaisesRegex(RuntimeError, "Argument input is not defined as mutable but was mutated"):
            with SchemaCheckMode():
                IncorrectAliasTensor(x).sin().cos().sub(IncorrectAliasTensor(y))

    # Tests that an exception is raised for a mismatching alias
    def test_alias_check_fail_simple(self):
        x = torch.rand((3, 3), requires_grad=True)
        y = torch.rand((3, 3))
        with self.assertRaisesRegex(RuntimeError, "Argument input is not defined to alias output but was aliasing"):
            with SchemaCheckMode():
                IncorrectAliasTensor(x).add(IncorrectAliasTensor(y), alpha=2)

    # Tests that an exception is raised for a mismatching alias over multiple ops
    def test_alias_check_fail_multiple_operators(self):
        x = torch.rand((3, 3), requires_grad=True)
        y = torch.zeros((3, 3), requires_grad=True)
        with self.assertRaisesRegex(RuntimeError, "Argument input is not defined to alias output but was aliasing"):
            with SchemaCheckMode():
                IncorrectAliasTensor(x).sin().relu().add(IncorrectAliasTensor(y), alpha=2)

    # Tests that an exception is raised for a centered mismatching alias over multiple ops
    def test_alias_check_fail_multiple_operators_centered(self):
        x = torch.rand((3, 3), requires_grad=True)
        y = torch.zeros((3, 3), requires_grad=True)
        with self.assertRaisesRegex(RuntimeError, "Argument input is not defined to alias output but was aliasing"):
            with SchemaCheckMode():
                IncorrectAliasTensor(x).sin().add(IncorrectAliasTensor(y), alpha=2).relu()

    # Tests that an exception is raised when two outputs unexpectedly alias each other
    def test_alias_check_fail_outputs_unexpectedly_aliasing(self):
        x = torch.rand((3, 3))
        with self.assertRaisesRegex(RuntimeError, "Outputs 0 and 1 alias unexpectedly"):
            with SchemaCheckMode():
                IncorrectAliasTensor(x).aminmax(dim=0)

    # When this file was written, python op registration didn't exist.
    # It's probably worth re-writing the entire file to use it,
    # but instead I just added extra tests.
    def test_alias_check_fail_custom_ops_secretly_aliasing(self):
        def f(x):
            return torch.ops.bad_schemas.secretly_aliasing(x)

        x = torch.rand((3, 3))
        with self.assertRaisesRegex(RuntimeError, "not defined to alias output but was aliasing"):
            with SchemaCheckMode():
                f(x)

    def test_alias_check_fail_custom_ops_secretly_mutating(self):
        def f(x):
            return torch.ops.bad_schemas.secretly_mutating(x)

        x = torch.rand((3, 3))
        with self.assertRaisesRegex(RuntimeError, "not defined as mutable but was mutated"):
            with SchemaCheckMode():
                f(x)

    def test_alias_check_fail_custom_ops_output_is_input(self):
        def f(x):
            return torch.ops.bad_schemas.output_is_input(x)

        x = torch.rand((3, 3))
        with self.assertRaisesRegex(RuntimeError, "are not allowed to directly return inputs"):
            with SchemaCheckMode():
                f(x)

    # Tests that is_alias_of returns as expected
    def test_is_alias_of_basic(self):
        x = torch.rand((3, 3), requires_grad=True)
        y = x.add(x, alpha=2)
        self.assertTrue(torch._C._is_alias_of(x, x))
        self.assertFalse(torch._C._is_alias_of(x, y))

    # Tests that is_alias_of returns as expected with empty containers
    def test_is_alias_of_empty_container(self):
        x = []
        y = torch.rand((3, 3), requires_grad=True)
        self.assertFalse(torch._C._is_alias_of(x, x))
        self.assertFalse(torch._C._is_alias_of(x, y))

    # Tests that overlaps returns as expected
    def test_overlaps_basic(self):
        x = torch.rand((3, 3), requires_grad=True)
        y = torch.rand((3, 3), requires_grad=True)
        z = [x, y]
        self.assertTrue(torch._C._overlaps(x, x))
        self.assertFalse(torch._C._overlaps(x, y))
        self.assertTrue(torch._C._overlaps(z, x))
        self.assertTrue(torch._C._overlaps(z, y))

    # Tests that overlaps returns correctly with empty containers
    def test_overlaps_empty_container(self):
        x = []
        y = [torch.rand((3, 3), requires_grad=True)]
        # Empty containers return false
        self.assertFalse(torch._C._overlaps(y, x))
        self.assertTrue(torch._C._overlaps(y, y))

    # Tests that SchemaInfo Bindings work as expected
    def test_schema_info_bind_basic(self):
        class SchemaInfoBindTestMode(TorchDispatchMode):
            def __init__(self, test_self):
                super().__init__()
                self.test_self = test_self

            def __torch_dispatch__(self, func, types, args=(), kwargs=None):
                # kwargs defaults to None; normalize before `**` unpacking.
                kwargs = kwargs or {}
                named_arg_list = normalize_function(
                    func,
                    args,
                    kwargs,
                    normalize_to_only_use_kwargs=True
                ).kwargs
                # One SchemaInfo is fed values one argument at a time, the
                # other all at once; both must reach the same conclusion.
                schema_info_value_test = torch._C._SchemaInfo(func._schema)
                schema_info_values_test = torch._C._SchemaInfo(func._schema)
                first_input = torch._C._SchemaArgument(torch._C._SchemaArgType.input, 0)
                second_input = torch._C._SchemaArgument(torch._C._SchemaArgType.input, 1)

                # From the schema alone, inputs 0 and 1 are not known to alias.
                self.test_self.assertFalse(schema_info_value_test.may_alias(first_input, second_input))
                self.test_self.assertFalse(schema_info_values_test.may_alias(first_input, second_input))

                for name in named_arg_list:
                    schema_info_value_test.add_argument_value(name, named_arg_list[name])
                schema_info_values_test.add_argument_values(named_arg_list)

                # Once the actual (identical) values are bound, they may alias.
                self.test_self.assertTrue(schema_info_value_test.may_alias(first_input, second_input))
                self.test_self.assertTrue(schema_info_values_test.may_alias(first_input, second_input))

                return func(*args, **kwargs)

        x = torch.rand((3, 3))
        with SchemaInfoBindTestMode(self):
            x.add(x)
|  |  | 
|  |  | 
class TestSchemaCheckModeOpInfo(JitTestCase):
    """Runs OpInfo sample inputs under SchemaCheckMode.

    Covers every op in ``op_db`` with each of its supported dtypes.
    """

    @ops(op_db, dtypes=OpDTypes.supported)
    def test_schema_correctness(self, device, dtype, op):
        """Check that running ``op`` does not violate its declared schema.

        SchemaCheckMode raises a RuntimeError when an op mutates or aliases
        arguments in ways its schema does not declare.
        """
        # Currently torch.equal isn't supported with torch.complex32.
        # NOTE(review): an earlier comment said complex64 and complex128 also
        # error, but those dtypes are not excluded here — confirm.
        if dtype == torch.complex32:
            # Report the case as skipped instead of silently passing.
            self.skipTest("torch.equal is not supported for torch.complex32")
        for sample in op.sample_inputs(device, dtype, requires_grad=False):
            with SchemaCheckMode():
                op(sample.input, *sample.args, **sample.kwargs)
|  |  | 
# Generate device-specific variants of TestSchemaCheckModeOpInfo, restricted
# to CPU and CUDA, and register them in this module's globals() so the test
# runner can discover them.
instantiate_device_type_tests(TestSchemaCheckModeOpInfo, globals(), only_for=("cpu", "cuda"))

# Allow running this file directly as a script.
if __name__ == '__main__':
    run_tests()