# Owner(s): ["module: primTorch"]

from collections import defaultdict

import torch
from torch import Tensor
import torch.autograd
from torch.utils._python_dispatch import enable_torch_dispatch_mode
from torch._decomp import decomposition_table

from torch.utils._pytree import tree_map, tree_flatten, tree_unflatten
from torch.utils._mode_utils import no_dispatch
from torch.testing._internal.common_utils import (
    is_iterable_of_tensors,
    TestCase,
    skipIfCrossRef,
    suppress_warnings,
    TEST_WITH_ASAN,
    run_tests,
)
from torch.testing._internal.common_device_type import (
    onlyNativeDeviceTypes,
    ops,
    instantiate_device_type_tests,
)
from torch.testing._internal.common_methods_invocations import op_db

import itertools
import functools
from functools import partial
import unittest

aten = torch.ops.aten


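# Maps an OpOverload to its base aten name, e.g. aten.add.Tensor -> "add".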
# TODO: this isn't going to work with non-aten namespaces
def overload_to_aten_name(overload):
    return overload._schema.name.split("::")[1]


# All operators that can have decomp tests
decomposition_names = {overload_to_aten_name(k) for k in decomposition_table}
_decomp_test_ops = [
    op
    for op in op_db
    if op.aten_name in decomposition_names
    or op.aten_backward_name in decomposition_names
]


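# Decides whether `arg` should be treated as a differentiable input: with
# requires_grad=True we look at the tensor's requires_grad flag, otherwise any
# floating-point or complex tensor counts. Iterables of tensors must be
# uniformly differentiable or uniformly not.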
def diff_arg(arg, requires_grad=True):
    def is_differentiable_arg(arg):
        if requires_grad:
            return arg.requires_grad
        else:
            return arg.is_floating_point() or arg.is_complex()

    if is_iterable_of_tensors(arg):
        if all(is_differentiable_arg(a) for a in arg):
            return True
        if all(not is_differentiable_arg(a) for a in arg):
            return False
        raise RuntimeError("NYI: The test runner can't handle this")
    return isinstance(arg, Tensor) and is_differentiable_arg(arg)


# Version of autograd.grad with some differences:
#   - pytree inputs are allowed (but the leaves of the pytree have to all
#     be tensors)
#   - if an input does not participate in the gradient computation, we
#     return a zero-filled tensor for it instead of None
def _autograd_grad(
    outputs, inputs, grad_outputs=None, retain_graph=False, create_graph=True
):
    inputs, inputs_spec = tree_flatten(inputs)
    diff_inputs = tuple(inp for inp in inputs if inp.requires_grad)
    if grad_outputs is None:
        diff_outputs = tuple(out for out in outputs if out.requires_grad)
    else:
        diff_grad_outputs = [
            (out, go) for out, go in zip(outputs, grad_outputs) if out.requires_grad
        ]
        if len(diff_grad_outputs) == 0:
            diff_outputs, grad_outputs = (), ()
        else:
            diff_outputs, grad_outputs = zip(*diff_grad_outputs)
    grad_inputs = torch.autograd.grad(
        diff_outputs,
        diff_inputs,
        grad_outputs,
        retain_graph=retain_graph,
        create_graph=create_graph,
        allow_unused=True,
    )
    result = []
    grad_inputs_iter = iter(grad_inputs)
    for inp in inputs:
        if inp.requires_grad:
            grad_input = next(grad_inputs_iter)
            if grad_input is None:
                result.append(torch.zeros_like(inp))
            else:
                result.append(grad_input)
        else:
            result.append(torch.zeros_like(inp))
    return tree_unflatten(result, inputs_spec)


def _as_tuple(val):
    if isinstance(val, tuple):
        return val
    return (val,)


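# Runs f(*primals) and returns (result, vjp_fn), where vjp_fn(cotangents)
# computes gradients with respect to the primals without building a new graph
# (create_graph=False); essentially a reference VJP built on plain autograd.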
def ref_vjp_no_create(f, *primals):
    result = f(*primals)

    def wrapped(cotangents):
        return _autograd_grad(
            _as_tuple(result), primals, _as_tuple(cotangents), create_graph=False
        )

    return result, wrapped


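# Default (rtol, atol) pairs keyed by dtype, consumed by _getDefaultRtolAndAtol
# below when no per-op tolerance override applies.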
dtype_precisions = {
    torch.float16: (0.001, 1e-5),
    torch.bfloat16: (0.016, 1e-4),
    torch.float32: (1.3e-6, 1e-5),
    torch.float64: (1e-7, 1e-7),
    torch.complex32: (0.001, 1e-5),
    torch.complex64: (1.3e-6, 1e-5),
    torch.complex128: (1e-7, 1e-7),
}


# Returns the "default" rtol and atol for comparing scalars or
# tensors of the given dtypes.
def _getDefaultRtolAndAtol(dtype0, dtype1):
    rtol = max(
        dtype_precisions.get(dtype0, (0, 0))[0], dtype_precisions.get(dtype1, (0, 0))[0]
    )
    atol = max(
        dtype_precisions.get(dtype0, (0, 0))[1], dtype_precisions.get(dtype1, (0, 0))[1]
    )
    return rtol, atol


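# Relative-accuracy check used for the fp16/bf16 path: both the eager result
# (`orig`) and the decomposition result (`decomp`) are compared against a
# float64 reference (`ref`), and we only fail if the decomposition strays
# further from the reference than eager does (beyond a small per-op atol).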
def op_assert_ref(test_case, op, test_dtype, i, orig, decomp, ref, args, kwargs):
    assert orig.dtype == decomp.dtype, f"{i} Operation:  {op}"
    if orig.numel() == 0 or decomp.numel() == 0:
        assert orig.numel() == decomp.numel()
        return
    assert orig.shape == decomp.shape, f"{i} Operation:  {op}"
    tol_table = {
        (torch.bfloat16, torch.ops.aten.native_layer_norm.default): 1e-5,
        (torch.float16, torch.ops.aten.native_layer_norm.default): 1e-5,
        (torch.bfloat16, torch.ops.aten.native_batch_norm.default): 1e-5,
        (torch.float16, torch.ops.aten.native_batch_norm.default): 1e-5,
    }
    if ref.is_floating_point():
        orig_diff = (orig - ref).abs().max()
        decomp_diff = (decomp - ref).abs().max()
        atol = tol_table.get((test_dtype, op), 1e-7)
        if decomp_diff > orig_diff + atol:
            raise RuntimeError(
                f"Difference from float64 is larger with decomposition {op.__name__}"
                f" than original on output {i}. Original max diff: {orig_diff}, Decomp max diff: {decomp_diff}\n"
                f"atol = {atol}\n"
                f"args = {args}\n"
                f"kwargs = {kwargs}"
            )
    else:
        test_case.assertEqual(
            orig, decomp, msg=f"{op.__name__}\nargs = {args}\nkwargs = {kwargs}"
        )


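# Direct-comparison path: checks the decomposition output against the eager
# output using dtype-based default tolerances, with a few per-(dtype, op)
# overrides.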
def op_assert_equal(test_case, op, test_dtype, orig, decomp, args, kwargs):
    test_case.assertEqual(
        orig.dtype,
        decomp.dtype,
        f"Operation: {op}, orig.dtype: {orig.dtype}, decomp.dtype: {decomp.dtype}, {args}, {kwargs}",
    )
    # Before adding an entry to this table, make sure your decomposition is right :)
    tol_table = {
        # Due to strange epsilon behaviors, see https://github.com/pytorch/pytorch/issues/73161
        (torch.float32, torch.ops.aten.native_layer_norm.default): (1e-3, 1e-3),
        (torch.float32, torch.ops.aten.native_layer_norm_backward.default): (
            1e-3,
            1e-3,
        ),
    }
    if (test_dtype, op) in tol_table:
        # NB: look up with the same (test_dtype, op) key we just checked for
        rtol, atol = tol_table[(test_dtype, op)]
    else:
        rtol, atol = _getDefaultRtolAndAtol(orig.dtype, decomp.dtype)

    test_case.assertEqual(
        orig, decomp, rtol=rtol, atol=atol,
        msg=f"{op.__name__}\nargs = {args}\nkwargs = {kwargs}",
    )


# Given f, returns an f' such that:
# - f' takes only positional arguments
# - All arguments to f' are floating-point or complex Tensors
# - All outputs of f' are floating-point or complex Tensors
def normalize_op_input_output2(
    f, args, kwargs, output_process_fn_grad=None, requires_grad=True
):
    flat_args, args_spec = tree_flatten(args)
    diff_argnums = tuple(
        i
        for i, arg in enumerate(flat_args)
        if diff_arg(arg, requires_grad=requires_grad)
    )
    assert len(diff_argnums) > 0
    primals = tuple(flat_args[i] for i in diff_argnums)

    @functools.wraps(f)
    def wrapped(*primals):
        _args = list(flat_args)
        for num, arg in zip(diff_argnums, primals):
            _args[num] = arg
        _args = tree_unflatten(_args, args_spec)
        result = f(*_args, **kwargs)
        if output_process_fn_grad is not None:
            result = output_process_fn_grad(result)
        if isinstance(result, tuple):
            # TODO: Remove the following hack for namedtuples
            result = tuple(result)
            result = tuple(
                r
                for r in result
                if isinstance(r, Tensor) and (r.is_floating_point() or r.is_complex())
            )
            assert len(result) > 0
        return result

    return wrapped, primals


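# For example (with the dtype=torch.float64 used by the relative check below):
#   upcast_tensor(t_f16, dtype=torch.float64)          -> t_f16.to(torch.float64)
#   upcast_tensor(torch.float16, dtype=torch.float64)  -> torch.float64
#   upcast_tensor(3, dtype=torch.float64)              -> 3  (everything else passes through)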
# NB: This also upcasts dtype arguments
# TODO: handle complex correctly
def upcast_tensor(x, dtype=torch.float32):
    if isinstance(x, Tensor) and x.dtype.is_floating_point:
        return x.to(dtype=dtype)
    elif isinstance(x, torch.dtype) and x in [torch.float16, torch.bfloat16]:
        return dtype
    else:
        return x


def normalize_op_input_output(f, sample, requires_grad=True):
    args = tuple([sample.input] + list(sample.args))
    return normalize_op_input_output2(
        f,
        args,
        sample.kwargs,
        sample.output_process_fn_grad,
        requires_grad=requires_grad,
    )


CROSS_REF_EXCLUDE_SET = {
    # CUBLAS_STATUS_NOT_SUPPORTED when calling
    # `cublasGemmStridedBatchedExFix(handle, opa, opb, (int)m, (int)n, (int)k,
    # (void*)&falpha, a, CUDA_R_16BF, (int)lda, stridea, b, CUDA_R_16BF,
    # (int)ldb, strideb, (void*)&fbeta, c, CUDA_R_16BF, (int)ldc, stridec,
    # (int)num_batches, CUDA_R_32F, CUBLAS_GEMM_DEFAULT_TENSOR_OP)`
    ("cuda", torch.bfloat16, "nn.functional.bilinear"),
    # randomness
    ("cuda", torch.float16, "nn.functional.dropout"),
    ("cuda", torch.bfloat16, "nn.functional.dropout"),
    ("cuda", torch.float64, "nn.functional.dropout"),
    ("cuda", torch.float32, "nn.functional.dropout"),
    # the decomposition still has problems even with opmath; doesn't work
    ("cuda", torch.bfloat16, "nn.functional.embedding"),
}

all_decomposed = set()
all_called = defaultdict(int)

# Helpful snippet for testing coverage
"""
import atexit
def check_coverage():
    print("missing coverage:")
    print("\n".join(map(str, decomposition_table.keys() - all_decomposed)))
atexit.register(check_coverage)
"""

# Helpful snippet for Horace to create his google sheet :)
"""
import atexit
def dump_ops():
    with open('run_ops.txt', 'w') as f, open('count_ops.txt', 'w') as g:
        for op, count in sorted(all_called.items(), key=lambda x: x[0].__name__):
            f.write(f'{op.__name__}\n')
            g.write(f'{count}\n')
    with open('run_decompositions.txt', 'w') as f:
        for op in sorted([i.__name__ for i in all_decomposed]):
            f.write(f'{op}\n')

atexit.register(dump_ops)
"""


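# Returns True if any (flattened) arg or kwarg is a kind of tensor the
# decompositions aren't expected to handle: sparse, sparse CSR, MKLDNN,
# quantized, nested, functional tensors, or Tensor-like subclasses.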
def any_unsupported(args, kwargs):
    def test_unsupported(t):
        if type(t) is torch.Tensor or type(t) is torch.nn.Parameter:
            # These are all things that we haven't coded decompositions
            # to handle correctly.  Maybe they should.
            return any([
                t.is_sparse_csr, t.is_sparse, t.is_mkldnn, t.is_quantized,
                t.is_nested, torch._is_functional_tensor(t),
            ])
        elif torch.overrides.is_tensor_like(t):
            # Decompositions will generally change the behavior of Tensor-like
            # subclasses, so bypass tests in this case too
            return True
        else:
            return False

    flat_args, _ = tree_flatten(args)
    flat_kwargs, _ = tree_flatten(kwargs)
    return any(test_unsupported(x) for x in itertools.chain(flat_args, flat_kwargs))


class TestDecomp(TestCase):
    longMessage = True

    # NB: This actually overlaps with test_comprehensive, but it only
    # runs on things that are definitely decomposed so it's a lot faster
    # to run
    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyNativeDeviceTypes
    @skipIfCrossRef
    @suppress_warnings
    @ops(_decomp_test_ops)
    def test_quick(self, device, dtype, op):
        self.do_cross_ref(device, dtype, op, run_all=False)

    @unittest.skipIf(TEST_WITH_ASAN, "Skipped under ASAN")
    @onlyNativeDeviceTypes
    @skipIfCrossRef
    @suppress_warnings
    @ops(op_db)
    def test_comprehensive(self, device, dtype, op):
        self.do_cross_ref(device, dtype, op, run_all=True)

    def do_cross_ref(self, device, dtype, op, *, run_all):
        if (torch.device(device).type, dtype, op.name) in CROSS_REF_EXCLUDE_SET or (
            None,
            dtype,
            op.name,
        ) in CROSS_REF_EXCLUDE_SET:
            self.skipTest(f"{op.name} in {dtype} not supported")

        test_dtype = dtype

        # We check the correctness of each decomposition right after running it.
        # So, when we encounter a decomposition, we run the function normally, and
        # then run the decomposition, and ensure they're identical.
        called = set()
        decomposed = set()

        saved_precision = self.precision
        saved_rel_tol = self.rel_tol

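        # DecompCrossRefMode intercepts every aten op dispatched while the mode
        # is active (see the enable_torch_dispatch_mode regions below). Calls
        # without an entry in decomposition_table (or with unsupported inputs)
        # are passed straight through; for everything else we run both the
        # eager kernel and its decomposition and compare the results.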
        class DecompCrossRefMode(torch.Tensor):
            @classmethod
            def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                with no_dispatch():
                    return cls._torch_dispatch(func, types, args, kwargs)

            @classmethod
            def _torch_dispatch(cls, func, types, args=(), kwargs=None):
                self.precision = saved_precision
                self.rel_tol = saved_rel_tol

                called.add(func)
                all_called[func] += 1

                # Stuff we shouldn't bother testing
                # (TODO: remove detach from the decomp table?)
                if func not in decomposition_table or func in [
                    torch.ops.aten.detach.default
                ] or any_unsupported(args, kwargs):
                    return func(*args, **kwargs)

                decomposed.add(func)
                all_decomposed.add(func)

                # We take two main strategies for verifying the correctness and
                # numerical stability of a decomposition. The first is simple
                # tolerance checking between decomp_out and pytorch_out.
                # However, for fp16/bf16 (and reductions in particular) this
                # becomes very finicky, as there are not many guarantees we can
                # make. So, for fp16/bf16, we instead compare how far decomp_out
                # and pytorch_out each are from pytorch_out_64, i.e. from the
                # fp64 "ground truth". If the decomposition introduces more
                # error than the original kernel, we raise an error.

                decomposition = decomposition_table[func]

                do_relative_check = test_dtype in [torch.float16, torch.bfloat16]
                real_out_unflat = func(*args, **kwargs)
                real_out, _ = tree_flatten(real_out_unflat)
                decomp_out, _ = tree_flatten(decomposition(*args, **kwargs))
                assert len(real_out) == len(decomp_out)

                if do_relative_check:
                    upcast = partial(upcast_tensor, dtype=torch.float64)
                    real_out_double, _ = tree_flatten(
                        func(*tree_map(upcast, args), **tree_map(upcast, kwargs))
                    )
                    for i, (orig, decomp, ref) in enumerate(
                        zip(real_out, decomp_out, real_out_double)
                    ):
                        if orig is None:
                            assert decomp is None
                            continue
                        op_assert_ref(
                            self, func, test_dtype, i, orig, decomp, ref, args, kwargs
                        )
                else:
                    for orig, decomp in zip(real_out, decomp_out):
                        if orig is None:
                            assert decomp is None
                            continue
                        op_assert_equal(self, func, test_dtype, orig, decomp, args, kwargs)

                return real_out_unflat

        requires_grad = (
            op.supports_autograd
            and dtype in op.supported_backward_dtypes(torch.device(device).type)
            # TODO: OpInfo really ought to error out for this case, but it's
            # not exercised in test_ops_gradients atm.  The problem is not
            # complex32 per se (which is supported by data movement only ops)
            # but that when we do backwards we expect other ops like add to work
            and not dtype == torch.complex32
        )
        samples = op.sample_inputs(device, test_dtype, requires_grad=requires_grad)

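        # Asserts that at least one call matching `aten_name` was actually
        # routed through a decomposition in the preceding dispatch region;
        # otherwise this test would not be exercising a decomposition at all.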
        def check_decomposed(aten_name):
            self.assertTrue(
                any(overload_to_aten_name(c) == aten_name for c in decomposed),
                msg=f"aten.{aten_name} was not decomposed, saw calls for: "
                + ", ".join(map(str, list(called))),
            )

        aten_name = op.decomp_aten_name or op.aten_name

        func = op.get_op()
        for sample_input in samples:
            if requires_grad:
                if None in sample_input.args:
                    continue

                fn, primals = normalize_op_input_output(func, sample_input)
                # NB: this tree_map is currently a no-op (every leaf maps to itself)
                primals = tree_map(
                    lambda x: x if isinstance(x, torch.Tensor) else x, primals
                )

                # Once https://github.com/pytorch/pytorch/pull/75965/ lands I can
                # store the called list on the mode object instance and no
                # explicit clearing is necessary as I will create a fresh mode
                # for each region
                decomposed.clear()
                with enable_torch_dispatch_mode(DecompCrossRefMode):
                    decomp_out, decomp_vjp_fn = ref_vjp_no_create(fn, *primals)
                if aten_name in decomposition_names:
                    check_decomposed(aten_name)

                if op.aten_backward_name in decomposition_names or run_all:
                    cotangents = tree_map(lambda x: torch.randn_like(x), decomp_out)

                    decomposed.clear()
                    with enable_torch_dispatch_mode(DecompCrossRefMode):
                        decomp_vjp_fn(cotangents)
                    if not run_all:
                        check_decomposed(op.aten_backward_name)

            elif aten_name in decomposition_names or run_all:
                args = [sample_input.input] + list(sample_input.args)
                kwargs = sample_input.kwargs
                decomposed.clear()
                with enable_torch_dispatch_mode(DecompCrossRefMode):
                    func(*args, **kwargs)
                if not run_all:
                    check_decomposed(aten_name)
            else:
                assert op.supports_autograd
                self.skipTest(
                    "only backwards is decomposed, but dtype doesn't support AD"
                )


instantiate_device_type_tests(TestDecomp, globals())

if __name__ == "__main__":
    run_tests()