| # Owner(s): ["module: unknown"] | 
 |  | 
 | from functools import partial | 
 | from textwrap import dedent | 
 |  | 
 | import torch | 
 |  | 
 | from torch.testing import FileCheck | 
 | from torch.testing._internal.common_device_type import ( | 
 |     instantiate_device_type_tests, | 
 |     OpDTypes, | 
 |     ops, | 
 | ) | 
 | from torch.testing._internal.common_jit import ( | 
 |     check_against_reference, | 
 |     JitCommonTestCase, | 
 | ) | 
 | from torch.testing._internal.common_methods_invocations import op_db | 
 | from torch.testing._internal.common_utils import ( | 
 |     clone_input_helper, | 
 |     first_sample, | 
 |     IS_SANDCASTLE, | 
 |     run_tests, | 
 |     TestCase, | 
 |     unMarkDynamoStrictTest, | 
 | ) | 
 | from torch.testing._internal.jit_metaprogramming_utils import ( | 
 |     check_alias_annotation, | 
 |     create_script_fn, | 
 |     create_traced_fn, | 
 | ) | 
 | from torch.testing._internal.jit_utils import ( | 
 |     disable_autodiff_subgraph_inlining, | 
 |     is_lambda, | 
 | ) | 
 |  | 
 | # variant testing is only done with torch.float and torch.cfloat to avoid | 
 | #   excessive test times and maximize signal to noise ratio | 
 | _variant_ops = partial( | 
 |     ops, dtypes=OpDTypes.supported, allowed_dtypes=(torch.float, torch.cfloat) | 
 | ) | 
 |  | 
 |  | 
 | # Tests operators for consistency between JIT and eager, also checks | 
 | #   correctness of JIT specific alias schemas and intended | 
 | #   autodifferentiation behavior. | 
 | # Inherits from JitCommonTestCase instead of TestCase directly to share | 
 | #   functionality with original test_jit.py method operator tests | 
 | @unMarkDynamoStrictTest | 
 | class TestJit(JitCommonTestCase): | 
 |     exact_dtype = True | 
 |  | 
 |     # Tests that the forward and backward passes of operations produce the | 
 |     #   same values for the cross-product of op variants (function, method, inplace) | 
 |     #   and runtimes (eager, traced, scripted). | 
 |     # TODO WARNING: inplace x {traced, scripted} not currently tested | 
 |     @_variant_ops(op_db) | 
 |     def test_variant_consistency_jit(self, device, dtype, op): | 
 |         _requires_grad = dtype in op.supported_backward_dtypes( | 
 |             torch.device(device).type | 
 |         ) | 
 |  | 
 |         include_conjugated_inputs = op.test_conjugated_samples and dtype.is_complex | 
 |         samples = op.sample_inputs( | 
 |             device, | 
 |             dtype, | 
 |             requires_grad=_requires_grad, | 
 |             include_conjugated_inputs=include_conjugated_inputs, | 
 |         ) | 
 |  | 
 |         # Acquires variants to test | 
 |         func = op.get_op() | 
 |         method = op.get_method() | 
 |         variants = { | 
 |             # TODO: inplace tests currently fail, fix and add inplace variant | 
 |             "function": func, | 
 |             "method": method, | 
 |         } | 
 |  | 
 |         # scripting strips the torch.ops prefix from these operators | 
 |         # incorrectly; don't bother testing this case.  Count this | 
 |         # as "testing" | 
 |         if isinstance(func, torch._ops.OpOverload): | 
 |             self.skipTest("variant consistency doesn't work on torch.ops") | 
 |  | 
 |         # TODO: find better way to standardize on op registration itself.. | 
 |         has_fake_function = op.name in ["resize_", "resize_as_"] | 
 |  | 
 |         if has_fake_function: | 
 |             variants = {"method": getattr(torch.Tensor, op.name)} | 
 |             samples = op.sample_inputs(device, dtype, requires_grad=False) | 
 |  | 
 |         tested = False | 
 |         for sample in samples: | 
 |             # Test traced and scripted consistency | 
 |             for func_type, variant in variants.items(): | 
 |                 if variant is None: | 
 |                     continue | 
 |  | 
 |                 # scripting and check_alias_analysis do not work with lambdas | 
 |                 # lambdas are typically used as a way to simulate methods without | 
 |                 # functional variants, so rely on the other variant for testing | 
 |                 # for now | 
 |                 if is_lambda(variant): | 
 |                     continue | 
 |  | 
 |                 tested = True | 
 |                 try: | 
 |                     self.indiv_variant_test_jit( | 
 |                         device, dtype, op, sample, func_type, variant, has_fake_function | 
 |                     ) | 
 |                 except Exception as e: | 
 |                     variant_error_info = dedent( | 
 |                         f""" | 
 |                         Error testing {op.name} {func_type} variant | 
 |                         with dtype: {dtype} | 
 |                         with inputs {sample}: | 
 |                     """ | 
 |                     ) | 
 |                     raise Exception(variant_error_info) from e  # noqa: TRY002 | 
 |  | 
 |         assert tested, "JIT Test does not execute any logic" | 
 |  | 
 |     def indiv_variant_test_jit( | 
 |         self, device, dtype, op, sample, func_type, variant, has_fake_function | 
 |     ): | 
 |         _requires_grad = dtype in op.supported_backward_dtypes( | 
 |             torch.device(device).type | 
 |         ) | 
 |         support_script = op.supports_scripting | 
 |         # Create accessor for script function variant | 
 |         name = op.name + "_" if func_type == "inplace" else op.name | 
 |  | 
 |         # run with disable_autodiff_subgraph_inlining(True) to test | 
 |         #   autodiff support. Context manager forces the graph to contain | 
 |         #   DifferentiableGraph nodes if they are present | 
 |         with disable_autodiff_subgraph_inlining(): | 
 |             # Check scripted forward, grad, and grad grad | 
 |             if support_script: | 
 |                 script_fn = create_script_fn(self, name, func_type) | 
 |  | 
 |             def out_fn(output): | 
 |                 # Processes the output for autograd | 
 |                 if sample.output_process_fn_grad is not None: | 
 |                     return sample.output_process_fn_grad(output) | 
 |                 return output | 
 |  | 
 |             def get_sample(): | 
 |                 return ( | 
 |                     clone_input_helper(sample.input) | 
 |                     if op.name[-1] == "_" | 
 |                     else sample.input | 
 |                 ) | 
 |  | 
 |             if support_script: | 
 |                 check_against_reference( | 
 |                     self, | 
 |                     script_fn, | 
 |                     op.get_op(), | 
 |                     out_fn, | 
 |                     (get_sample(),) + sample.args, | 
 |                     sample.kwargs, | 
 |                     no_grad=not _requires_grad, | 
 |                     no_gradgrad=not op.supports_gradgrad, | 
 |                 ) | 
 |  | 
 |             # Check traced forward, grad, and grad grad | 
 |             # TODO: fix tracing here | 
 |             supports_tracing = op.supports_tracing and not has_fake_function | 
 |             if op.assert_jit_shape_analysis: | 
 |                 self.assertTrue(supports_tracing) | 
 |  | 
 |             if supports_tracing: | 
 |                 traced_fn = create_traced_fn(self, variant) | 
 |                 check_against_reference( | 
 |                     self, | 
 |                     traced_fn, | 
 |                     op.get_op(), | 
 |                     out_fn, | 
 |                     (get_sample(),) + sample.args, | 
 |                     sample.kwargs, | 
 |                     no_grad=not _requires_grad, | 
 |                     no_gradgrad=not op.supports_gradgrad, | 
 |                 ) | 
 |  | 
 |             # Check alias annotation schema for correctness (make | 
 |             #   sure inputs that aren't supposed to be modified aren't) | 
 |             # Note: only runs in float32 because schema isn't affected by dtype, | 
 |             #   so running it on all dtypes is would be excessive | 
 |             if dtype == torch.float32: | 
 |                 # TODO: no reason why we cant run this with tracing graph | 
 |                 if support_script and op.name != "rsub": | 
 |                     check_alias_annotation( | 
 |                         name, | 
 |                         (get_sample(),) + sample.args, | 
 |                         sample.kwargs, | 
 |                         func_type=func_type, | 
 |                         aten_name=op.aten_name, | 
 |                     ) | 
 |  | 
 |                 # TODO: use script graph as well | 
 |                 checked_shape_analysis = False | 
 |                 if supports_tracing: | 
 |                     out = variant(get_sample(), *sample.args, **sample.kwargs) | 
 |  | 
 |                     # right now, tuple of outputs and tensor output supported | 
 |                     # TODO: list of tensor outputs | 
 |                     tuple_of_tensors = isinstance(out, tuple) and all( | 
 |                         isinstance(elem, torch.Tensor) for elem in out | 
 |                     ) | 
 |  | 
 |                     if isinstance(out, torch.Tensor) or tuple_of_tensors: | 
 |                         if tuple_of_tensors: | 
 |                             sizes = [elem.size() for elem in out] | 
 |                         else: | 
 |                             sizes = out.size() | 
 |                         self.checkShapeAnalysis( | 
 |                             sizes, traced_fn.graph, op.assert_jit_shape_analysis | 
 |                         ) | 
 |                         checked_shape_analysis = True | 
 |                 if op.assert_jit_shape_analysis: | 
 |                     self.assertTrue(checked_shape_analysis) | 
 |  | 
 |             # Check autodifferentiation of nodes for traced and scripted graphs, only need to check once per sample | 
 |             if dtype is torch.float32: | 
 |                 # Sandcastle doesn't fuse nodes | 
 |                 if IS_SANDCASTLE: | 
 |                     # fusible nodes are expected to be found in FusionGroups in the DifferentiableGraphs | 
 |                     nonfusible_nodes = ( | 
 |                         op.autodiff_nonfusible_nodes + op.autodiff_fusible_nodes | 
 |                     ) | 
 |                     fusible_nodes = [] | 
 |                 else: | 
 |                     nonfusible_nodes = op.autodiff_nonfusible_nodes | 
 |                     fusible_nodes = op.autodiff_fusible_nodes | 
 |  | 
 |                 if supports_tracing: | 
 |                     self.assertAutodiffNode( | 
 |                         traced_fn.last_graph, | 
 |                         op.assert_autodiffed, | 
 |                         nonfusible_nodes, | 
 |                         fusible_nodes, | 
 |                     ) | 
 |                 if support_script: | 
 |                     self.assertAutodiffNode( | 
 |                         script_fn.last_graph, | 
 |                         op.assert_autodiffed, | 
 |                         nonfusible_nodes, | 
 |                         fusible_nodes, | 
 |                     ) | 
 |  | 
 |     # alias testing is only done with torch.float for the same reason | 
 |     _alias_ops = partial(ops, dtypes=OpDTypes.supported, allowed_dtypes=(torch.float,)) | 
 |  | 
 |     @_alias_ops(op for op in op_db if op.aliases) | 
 |     def test_jit_alias_remapping(self, device, dtype, op): | 
 |         # NOTE: only tests on first sample | 
 |         samples = op.sample_inputs(device, dtype, requires_grad=True) | 
 |         sample = first_sample(self, samples) | 
 |  | 
 |         # [Scripting Data Preparation] | 
 |         # Prepare data for test scripting | 
 |         # Below we prepare strings of args/kwargs with and without type annotations. | 
 |         # These strings are inserted into function template strings which is then torch scripted. | 
 |         # - args string is ["t0"] corresponding to the "input" tensor required by the op | 
 |         # - args_kw is the value of args and strings of kwargs used to call the op (without type annotations), for example, | 
 |         # ["to", "1.0", "(1,)", "True", "tensor(1.0)"] -> def fn(t0): return variant(t0, 1.0, (1,), True, tensor(1.0)) | 
 |         args = ["t0"] | 
 |  | 
 |         def quote_strs(v): | 
 |             if isinstance(v, str): | 
 |                 return f"'{v}'" | 
 |  | 
 |             return str(v) | 
 |  | 
 |         args_kw = ( | 
 |             args | 
 |             + [f"{v}" for v in sample.args] | 
 |             + [f"{k}={quote_strs(v)}" for k, v in sample.kwargs.items()] | 
 |         ) | 
 |  | 
 |         # Prepare data for test tracing | 
 |         sample_args_kwargs = () | 
 |         if len(sample.args) > 0: | 
 |             sample_args_kwargs += (sample.args,) | 
 |         if len(sample.kwargs) > 0: | 
 |             sample_args_kwargs += (sample.kwargs,) | 
 |  | 
 |         original_name = op.aten_name | 
 |         original_name_inplace = original_name + "_" | 
 |         expected_dtype = op(sample.input, *sample.args, **sample.kwargs).dtype | 
 |  | 
 |         for a_op in op.aliases: | 
 |             inplace = a_op.inplace_variant | 
 |             method_or_inplace = [a_op.inplace_variant, a_op.method_variant] | 
 |             variants = ( | 
 |                 v | 
 |                 for v in (a_op.op, a_op.method_variant, a_op.inplace_variant) | 
 |                 if v is not None | 
 |             ) | 
 |  | 
 |             # Test scripting: | 
 |             for variant in variants: | 
 |                 variant_name = variant.__name__ | 
 |                 op_name = original_name_inplace if variant is inplace else original_name | 
 |  | 
 |                 if variant in method_or_inplace: | 
 |                     fn_template = """ | 
 |                         def _fn(t0{c}): | 
 |                             return t0.{alias_name}({args_kw}) | 
 |                     """ | 
 |                     # remove the first input tensor | 
 |                     script = fn_template.format( | 
 |                         c=", " if len(args_kw[1:]) > 1 else "", | 
 |                         args_kw=", ".join(args_kw[1:]), | 
 |                         alias_name=variant_name, | 
 |                     ) | 
 |                 else: | 
 |                     fn_template = """ | 
 |                         def _fn({args}): | 
 |                             return variant({args_kw}) | 
 |                     """ | 
 |                     script = fn_template.format( | 
 |                         args=", ".join(args), | 
 |                         args_kw=", ".join(args_kw), | 
 |                     ) | 
 |  | 
 |                 # Required to avoid undefined value: tensor error in JIT | 
 |                 # compilation of the function template | 
 |                 script = script.replace("tensor(", "torch.tensor(") | 
 |  | 
 |                 scripted = torch.jit.CompilationUnit(script)._fn | 
 |  | 
 |                 if variant is inplace and not torch.can_cast(expected_dtype, dtype): | 
 |                     try: | 
 |                         inp = clone_input_helper(sample.input) | 
 |                         scripted(inp) | 
 |                     except Exception as e: | 
 |                         continue | 
 |                     self.fail( | 
 |                         "Inplace operation on integer tensor that should be promoted to float didn't fail!" | 
 |                     ) | 
 |  | 
 |                 inp = clone_input_helper(sample.input) | 
 |                 scripted(inp) | 
 |                 inp = clone_input_helper(sample.input) | 
 |                 graph = scripted.graph_for(inp) | 
 |                 FileCheck().check(op.aten_name).check_not(variant_name).run(graph) | 
 |  | 
 |             # Test tracing: | 
 |             for variant in variants: | 
 |                 variant_name = variant.__name__ | 
 |                 op_name = original_name_inplace if variant is inplace else original_name | 
 |  | 
 |                 def _fn(*sample_args, **sample_kwargs): | 
 |                     return variant(*sample_args, **sample_kwargs) | 
 |  | 
 |                 inp = (clone_input_helper(sample.input),) + sample_args_kwargs | 
 |                 traced = torch.jit.trace(_fn, *inp) | 
 |                 inp = (clone_input_helper(sample.input),) + sample_args_kwargs | 
 |                 traced(*inp) | 
 |                 inp = (clone_input_helper(sample.input),) + sample_args_kwargs | 
 |                 graph = traced.graph_for(*inp) | 
 |                 FileCheck().check(op_name).check_not(variant_name).run(graph) | 
 |  | 
 |  | 
# Hand TestJit and this module's globals to the device-type test machinery.
# NOTE(review): presumably this generates the per-device test classes that the
#   runner discovers -- confirm against common_device_type.
instantiate_device_type_tests(TestJit, globals())

if __name__ == "__main__":
    # Script entry point only: set TestCase's default-dtype check flag, then
    #   run the tests.
    TestCase._default_dtype_check_enabled = True
    run_tests()