| # Owner(s): ["module: nn"] | 
 | from dataclasses import dataclass | 
 | from functools import partial | 
 | from itertools import product, chain | 
 | import unittest | 
 |  | 
 | import torch | 
 | import torch.nn as nn | 
 | import torch.nn.functional as F | 
 | from torch.nn import CrossEntropyLoss | 
 | from torch.nn.utils._per_sample_grad import call_for_per_sample_grads | 
 | from torch.testing._internal.common_cuda import TEST_CUDA, tf32_off | 
 | from torch.testing._internal.common_device_type import OpDTypes, instantiate_device_type_tests, ops | 
 | from torch.testing._internal.common_modules import module_db, modules | 
 | from torch.testing._internal.common_nn import TestBase, module_tests, new_module_tests | 
 | from torch.testing._internal.common_utils import TestCase, freeze_rng_state, make_tensor, run_tests, parametrize, skipIfTorchDynamo | 
 | from torch.testing._internal.common_methods_invocations import SampleInput, op_db | 
 | from torch.nn.utils._expanded_weights import ExpandedWeight | 
 | from torch.nn.utils._expanded_weights.expanded_weights_utils import forward_helper, set_grad_sample_if_exists, \ | 
 |     unpack_expanded_weight_or_tensor, sum_over_all_but_batch_and_last_n, standard_kwargs | 
 | from torch.utils._pytree import tree_map_only | 
 |  | 
 |  | 
 | class TestContext: | 
 |     pass | 
 |  | 
 | class TestExpandedWeightHelperFunction(TestCase): | 
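    # Unit tests for the private helpers in torch.nn.utils._expanded_weights.expanded_weights_utils.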
 |     def test_forward_helper(self, device): | 
 |         input = torch.randn(3, 4, device=device) | 
 |         weight = torch.randn(5, 4, device=device) | 
 |         bias = torch.randn(5, device=device) | 
 |         for (weight_batched, bias_batched) in product([True, False], [True, False]): | 
 |             maybe_batched_weight = weight | 
 |             maybe_batched_bias = bias | 
 |             if weight_batched: | 
 |                 maybe_batched_weight = ExpandedWeight(weight.clone().requires_grad_(), 3, loss_reduction="sum") | 
 |             if bias_batched: | 
 |                 maybe_batched_bias = ExpandedWeight(bias.clone().requires_grad_(), 3, loss_reduction="sum") | 
 |             args = (input, maybe_batched_weight, maybe_batched_bias) | 
 |             expanded_args, expanded_kwargs = standard_kwargs(('bias',), args) | 
 |             res = forward_helper(nn.functional.linear, expanded_args, expanded_kwargs) | 
 |             expected = nn.functional.linear(input, weight, bias) | 
 |             self.assertEqual(res, expected) | 
 |  | 
 |             self.assertEqual(len(expanded_args), 2) | 
            assert expanded_args[0] is args[0]  # avoids property checks in assertEqual
            assert expanded_args[1] is args[1]  # avoids property checks in assertEqual
            self.assertEqual(len(expanded_kwargs), 1)
            assert expanded_kwargs['bias'] is args[2]  # avoids property checks in assertEqual
 |  | 
 |     def test_forward_helper_failure_args(self, device): | 
 |         weight = torch.randn(5, 4, device=device) | 
 |         bias = torch.randn(5, device=device) | 
 |         with self.assertRaisesRegex(RuntimeError, r"do not support inputs that are also ExpandedWeights."): | 
 |             input = ExpandedWeight(torch.randn(3, 4, requires_grad=True), 3, loss_reduction="sum") | 
 |             expanded_args, expanded_kwargs = standard_kwargs(('bias',), (input, weight, bias)) | 
 |             forward_helper(nn.functional.linear, expanded_args, expanded_kwargs) | 
 |         with self.assertRaisesRegex(RuntimeError, r"requires a Tensor as the first input"): | 
 |             expanded_args, expanded_kwargs = standard_kwargs(('bias',), (3, weight, bias)) | 
 |             forward_helper(nn.functional.linear, expanded_args, expanded_kwargs) | 
 |         with self.assertRaisesRegex(RuntimeError, r"requires a batch dimension but got an input of size 0"): | 
 |             expanded_args, expanded_kwargs = standard_kwargs(('bias',), (torch.tensor(3), weight, bias)) | 
 |             forward_helper(nn.functional.linear, expanded_args, expanded_kwargs) | 
 |         with self.assertRaisesRegex(RuntimeError, r"0 is not a valid batch size for Expanded Weights"): | 
 |             expanded_args, expanded_kwargs = standard_kwargs(('bias',), (torch.randn(0, 1, 2), weight, bias)) | 
 |             forward_helper(nn.functional.linear, expanded_args, expanded_kwargs) | 
 |         input = torch.randn(3, 4) | 
 |         for (weight_batched, bias_batched) in product([True, False], [True, False]): | 
 |             if not weight_batched and not bias_batched: | 
 |                 continue | 
 |             maybe_batched_weight = weight | 
 |             maybe_batched_bias = bias | 
 |             if weight_batched: | 
 |                 maybe_batched_weight = ExpandedWeight(weight.clone().requires_grad_(), 4, loss_reduction="sum") | 
 |             if bias_batched: | 
 |                 maybe_batched_bias = ExpandedWeight(bias.clone().requires_grad_(), 4, loss_reduction="sum") | 
 |             with self.assertRaisesRegex(RuntimeError, r"Expected ExpandedWeights to have batch size matching input"): | 
 |                 expanded_args, expanded_kwargs = standard_kwargs(('bias',), (input, maybe_batched_weight, maybe_batched_bias)) | 
 |                 forward_helper(nn.functional.linear, expanded_args, expanded_kwargs) | 
 |  | 
 |     def test_set_grad_sample_if_exists(self, device): | 
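        # test_fn closes over grad_sample, which is assigned below before set_grad_sample_if_exists is called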
 |         def test_fn(a): | 
 |             return grad_sample | 
 |  | 
 |         orig_weight = torch.randn(4, device=device, requires_grad=True) | 
 |         expanded_weight = ExpandedWeight(orig_weight, 3, loss_reduction="sum") | 
 |         grad_sample = torch.randn(3) | 
 |         set_grad_sample_if_exists(expanded_weight, test_fn) | 
 |         self.assertTrue(hasattr(orig_weight, 'grad_sample')) | 
 |         self.assertEqual(orig_weight.grad_sample, grad_sample) | 
 |  | 
 |         basic_tensor = torch.randn(4, device=device) | 
 |         set_grad_sample_if_exists(basic_tensor, test_fn) | 
 |         self.assertFalse(hasattr(basic_tensor, 'grad_sample')) | 
 |  | 
 |         non_tensor = 3 | 
 |         set_grad_sample_if_exists(non_tensor, test_fn) | 
 |         self.assertFalse(hasattr(non_tensor, 'grad_sample')) | 
 |  | 
 |     def test_set_grad_sample_if_exists_failure(self, device): | 
 |         def test_fn(a): | 
 |             return True | 
 |  | 
 |         grad_tensor = torch.randn(4, requires_grad=True, device=device) | 
 |         with self.assertRaisesRegex(RuntimeError, r"does not support a mixture of ExpandedWeight parameters and normal Parameters"): | 
 |             set_grad_sample_if_exists(grad_tensor, test_fn) | 
 |  | 
 |     def test_unpack_expanded_weight_or_tensor(self, device): | 
 |         input = torch.randn(3, requires_grad=True, device=device) | 
 |         self.assertEqual(input, unpack_expanded_weight_or_tensor(ExpandedWeight(input, 3, loss_reduction="sum"))) | 
 |  | 
 |         input.requires_grad_(False) | 
 |         self.assertEqual(input, unpack_expanded_weight_or_tensor(input)) | 
 |         self.assertTrue(unpack_expanded_weight_or_tensor(4) is None) | 
 |  | 
 |     def test_unpack_expanded_weight_or_tensor_with_custom_function(self, device): | 
 |         input = torch.randn(3, requires_grad=True, device=device) | 
 |         self.assertTrue(unpack_expanded_weight_or_tensor(ExpandedWeight(input, 3, loss_reduction="sum"), lambda x: x is input)) | 
 |  | 
 |         input.requires_grad_(False) | 
 |         self.assertTrue(unpack_expanded_weight_or_tensor(input, lambda x: x is input)) | 
 |         self.assertTrue(unpack_expanded_weight_or_tensor(4, lambda x: x is input) is None) | 
 |  | 
 |     def test_unpack_expanded_weight_or_tensor_failure(self, device): | 
 |         input = torch.randn(3, requires_grad=True, device=device) | 
 |         with self.assertRaisesRegex(RuntimeError, r"does not support a mixture of ExpandedWeight parameters and normal Parameters"): | 
 |             unpack_expanded_weight_or_tensor(input) | 
 |  | 
 |         with self.assertRaisesRegex(RuntimeError, r"does not support a mixture of ExpandedWeight parameters and normal Parameters"): | 
 |             unpack_expanded_weight_or_tensor(input, lambda x: x is input) | 
 |  | 
 |     def test_sum_over_all_but_batch_and_last_n(self, device): | 
 |         input = torch.randn(1, 2, 3, 4, 5, device=device) | 
 |         res = sum_over_all_but_batch_and_last_n(input, 2) | 
 |         expected = input.sum((1, 2)) | 
 |         self.assertEqual(res, expected) | 
 |  | 
 |         res = sum_over_all_but_batch_and_last_n(input, 0) | 
 |         expected = input.sum((1, 2, 3, 4)) | 
 |         self.assertEqual(res, expected) | 
 |  | 
 |         res = sum_over_all_but_batch_and_last_n(input, 4) | 
 |         self.assertEqual(res, input) | 
 |  | 
 | class TestExpandedWeightFunctional(TestCase): | 
 |     def _compare_ew_and_for_loop_per_sample_grads(self, op, sample_input, reduction): | 
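        # Computes per-sample grads twice, once through ExpandedWeight wrappers and once with a
        # for loop over individual samples, then checks that the two results match.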
 |         input = sample_input.input | 
 |         args = sample_input.args | 
 |         kwargs = sample_input.kwargs | 
 |         batch_size = input.shape[0] if len(input.shape) > 1 else 1 | 
 |  | 
 |         # get per sample grads with ExpandedWeights objects | 
 |         loss_reduction = "sum" if reduction == torch.sum else "mean" | 
 |         (ew_input, ew_args, ew_kwargs) = make_expanded_weight(sample_input, batch_size, loss_reduction) | 
 |         diff_input_list = (ew_input,) + tuple(ew_args) + tuple(ew_kwargs.values()) | 
 |         diff_input_list = [i for i in diff_input_list if is_diff_tensor(i)] | 
 |         diff_input_list = [i.orig_weight if isinstance(i, ExpandedWeight) else i for i in diff_input_list] | 
 |         if not diff_input_list: | 
 |             return | 
 |         result = run_op(op, ew_input, *ew_args, **ew_kwargs) | 
        reduction(result).backward()  # use .backward(); torch.autograd.grad doesn't work here because ExpandedWeight relies on __torch_function__
 |         expanded_weight_grad = tuple(i.grad_sample if hasattr(i, "grad_sample") else i.grad for i in diff_input_list) | 
 |  | 
 |         # get per sample grads with for loop | 
 |         func = partial(run_op, op) | 
 |  | 
 |         per_sample_grad = for_loop_per_sample_grad(batch_size, reduction, input, func, *args, **kwargs) | 
 |  | 
 |         # check equality | 
 |         self.assertEqual(len(per_sample_grad), len(expanded_weight_grad)) | 
 |         if loss_reduction == "mean": | 
 |             # don't check equality of `input.grad`s since these vanilla tensors won't be scaled | 
 |             expanded_weight_grad = expanded_weight_grad[1:] | 
 |             per_sample_grad = per_sample_grad[1:] | 
 |         for (result_grad, expected_grad) in zip(expanded_weight_grad, per_sample_grad): | 
 |             self.assertEqual(result_grad, expected_grad) | 
 |  | 
 |     @ops(filter(lambda op: op.supports_expanded_weight, op_db), dtypes=OpDTypes.supported, allowed_dtypes=(torch.double,)) | 
 |     def test_expanded_weight_per_sample_grad_sum(self, device, dtype, op): | 
 |         sample_inputs = op.sample_inputs(device, dtype, requires_grad=True) | 
 |         for sample_input in supported_inputs(op, sample_inputs): | 
 |             if op.name == "nn.functional.embedding":  # embedding flips its argument order for autograd tests | 
 |                 sample_input = SampleInput(sample_input.args[0], args=(sample_input.input,), kwargs=sample_input.kwargs) | 
 |  | 
 |             self._compare_ew_and_for_loop_per_sample_grads(op, sample_input, torch.sum) | 
 |  | 
 |     @ops(filter(lambda op: op.supports_expanded_weight, op_db), dtypes=OpDTypes.supported, allowed_dtypes=(torch.double,)) | 
 |     def test_expanded_weight_per_sample_grad_mean(self, device, dtype, op): | 
 |         sample_inputs = op.sample_inputs(device, dtype, requires_grad=True) | 
 |         for sample_input in supported_inputs(op, sample_inputs): | 
 |             if op.name == "nn.functional.embedding":  # embedding flips its argument order for autograd tests | 
 |                 sample_input = SampleInput(sample_input.args[0], args=(sample_input.input,), kwargs=sample_input.kwargs) | 
 |  | 
 |             self._compare_ew_and_for_loop_per_sample_grads(op, sample_input, torch.mean) | 
 |  | 
 |     @ops(filter(lambda op: op.supports_expanded_weight, op_db), dtypes=OpDTypes.supported, allowed_dtypes=(torch.double,)) | 
 |     def test_expanded_weights_per_sample_grad_input_no_grad(self, device, dtype, op): | 
 |         sample_inputs = op.sample_inputs(device, dtype, requires_grad=True) | 
 |         for sample_input in supported_inputs(op, sample_inputs): | 
 |             if op.name == "nn.functional.embedding":  # embedding flips its argument order for autograd tests | 
 |                 sample_input = SampleInput(sample_input.args[0], args=(sample_input.input,), kwargs=sample_input.kwargs) | 
 |             sample_input.input.requires_grad_(False) | 
 |  | 
 |             self._compare_ew_and_for_loop_per_sample_grads(op, sample_input, torch.mean) | 
 |  | 
 |     @skipIfTorchDynamo("Checking error message doesn't work with dynamo") | 
 |     @ops(filter(lambda op: op.supports_expanded_weight, op_db), dtypes=OpDTypes.supported, allowed_dtypes=(torch.double,)) | 
 |     def test_unsupported_expand_weights(self, device, dtype, op): | 
 |         sample_inputs = op.sample_inputs(device, dtype, requires_grad=True) | 
 |         unsupported_inputs = supported_inputs(op, sample_inputs, supported_inputs=False) | 
 |         for sample_input in unsupported_inputs: | 
 |             with self.assertRaisesRegex(RuntimeError, r"Expanded Weights"): | 
 |                 if op.name == "nn.functional.embedding":  # embedding flips its argument order for autograd tests | 
 |                     sample_input = SampleInput(sample_input.args[0], args=(sample_input.input,), kwargs=sample_input.kwargs) | 
 |                 input = sample_input.input | 
 |  | 
 |                 batch_size = input.shape[0] if len(input.shape) > 1 else 1 | 
 |  | 
 |                 # get per sample grads with ExpandedWeights objects | 
 |                 (ew_input, ew_args, ew_kwargs) = make_expanded_weight(sample_input, batch_size) | 
 |                 result = run_op(op, ew_input, *ew_args, **ew_kwargs) | 
 |                 diff_input_list = (ew_input,) + tuple(ew_args) + tuple(ew_kwargs.values()) | 
 |                 diff_input_list = [i for i in diff_input_list if is_diff_tensor(i)] | 
 |                 diff_input_list = [i.orig_weight if isinstance(i, ExpandedWeight) else i for i in diff_input_list] | 
                result.sum().backward()  # use .backward(); torch.autograd.grad doesn't work here because ExpandedWeight relies on __torch_function__
 |  | 
 |     @ops(filter(lambda op: op.supports_expanded_weight, op_db), dtypes=OpDTypes.supported) | 
 |     def test_expanded_weight_forward(self, device, dtype, op): | 
 |         sample_inputs = op.sample_inputs(device, dtype) | 
 |         for sample_input in supported_inputs(op, sample_inputs): | 
 |             if op.name == "nn.functional.embedding":  # embedding flips its argument order for autograd tests | 
 |                 sample_input = SampleInput(sample_input.args[0].clone(), | 
 |                                            args=(sample_input.input.clone(),), | 
 |                                            kwargs=sample_input.kwargs) | 
 |                 if "cuda" in device and "max_norm" in sample_input.kwargs and "padding_idx" in sample_input.kwargs: | 
 |                     self.skipTest("embedding is non-determinstic in this case, see issue #74679") | 
 |             batch_size = sample_input.input.shape[0] if len(sample_input.input.shape) > 1 else 1 | 
 |             for loss_reduction in ["sum", "mean"]: | 
 |                 (ew_input, ew_args, ew_kwargs) = make_expanded_weight(sample_input, batch_size, loss_reduction) | 
 |                 expanded_weight_result = run_op(op, ew_input, *ew_args, **ew_kwargs) | 
 |                 normal_result = run_op(op, sample_input.input, *sample_input.args, **sample_input.kwargs) | 
 |                 self.assertEqual(expanded_weight_result, normal_result) | 
 |  | 
 |     def test_expanded_weight_error(self, device): | 
 |         batch_size = 3 | 
 |         sample_input = make_tensor((batch_size, 4), dtype=torch.float32, device=device, requires_grad=True) | 
        sample_weight = make_tensor((4,), dtype=torch.float32, device=device, requires_grad=True)
 |         with self.assertRaisesRegex(RuntimeError, r"Expanded Weights encountered but cannot handle function"): | 
 |             torch.add(sample_input, ExpandedWeight(sample_weight, batch_size, loss_reduction="sum")) | 
 |  | 
 |     def _test_embedding_model(self, model, num_embedding, device): | 
 |         batch_size = 32 | 
 |         input = torch.randint(0, num_embedding, (batch_size, 5, 5), device=device) | 
 |         return self._test_model(partial(model, num_embedding=num_embedding), batch_size, input, device) | 
 |  | 
 |     def _test_conv_model(self, model, input_size, num_dim, device, loss_reduction="sum", atol=1e-4, rtol=5e-5): | 
 |         batch_size = 32 | 
 |         input_ending = [input_size] * num_dim | 
 |         input = torch.randn([batch_size, 3] + input_ending, device=device) | 
 |         return self._test_model(partial(model, num_dim=num_dim), batch_size, input, device, loss_reduction, atol, rtol) | 
 |  | 
 |     def _test_model(self, model, batch_size, input, device, loss_reduction="sum", atol=1e-4, rtol=5e-5): | 
 |         model = model(10).to(device) | 
 |         targets = torch.randint(0, 10, (batch_size,), device=device) | 
 |         criterion = CrossEntropyLoss(reduction=loss_reduction) | 
 |         result = call_for_per_sample_grads(model, loss_reduction=loss_reduction)(input) | 
 |         loss = criterion(result, targets) | 
 |         loss.backward() | 
 |         result = [] | 
 |         for weight in model.parameters(): | 
 |             result.append(weight.grad_sample) | 
 |             del weight.grad_sample | 
 |  | 
 |         expected = [] | 
 |         for i in range(batch_size): | 
 |             loss = criterion(model(input[i].unsqueeze(0)), targets[i].unsqueeze(0)) | 
 |             expected.append(torch.autograd.grad(loss, model.parameters(), torch.ones_like(loss))) | 
 |  | 
 |         expected = [torch.stack(grad) for grad in zip(*expected)] | 
 |         for (res, exp) in zip(result, expected): | 
 |             self.assertEqual(res, exp, atol=atol, rtol=rtol) | 
 |  | 
 |     def _compute_tolerances(self, device): | 
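        # the conv models below need a looser absolute tolerance on sm8.6 GPUs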
 |         is_cuda_sm86 = device.startswith("cuda") and torch.cuda.get_device_capability(0) == (8, 6) | 
 |         return (9e-3, 5e-5) if is_cuda_sm86 else (1e-4, 5e-5) | 
 |  | 
 |     @tf32_off() | 
 |     def test_cnn_model_sum(self, device): | 
 |         def convnet(num_classes, num_dim): | 
 |             return nn.Sequential( | 
 |                 nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1), | 
 |                 nn.ReLU(), | 
 |                 nn.AvgPool2d(kernel_size=2, stride=2), | 
 |                 nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1), | 
 |                 nn.ReLU(), | 
 |                 nn.AvgPool2d(kernel_size=2, stride=2), | 
 |                 nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), | 
 |                 nn.ReLU(), | 
 |                 nn.AvgPool2d(kernel_size=2, stride=2), | 
 |                 nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1), | 
 |                 nn.ReLU(), | 
 |                 nn.AdaptiveAvgPool2d((1, 1)), | 
 |                 nn.Flatten(start_dim=1, end_dim=-1), | 
 |                 nn.Linear(128, num_classes, bias=True), | 
 |             ) | 
 |  | 
 |         atol, rtol = self._compute_tolerances(device) | 
 |         return self._test_conv_model(convnet, 28, 2, device, atol=atol, rtol=rtol) | 
 |  | 
 |     @tf32_off() | 
 |     def test_cnn_model_mean(self, device): | 
 |         def convnet(num_classes, num_dim): | 
 |             return nn.Sequential( | 
 |                 nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1), | 
 |                 nn.ReLU(), | 
 |                 nn.AvgPool2d(kernel_size=2, stride=2), | 
 |                 nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1), | 
 |                 nn.ReLU(), | 
 |                 nn.AvgPool2d(kernel_size=2, stride=2), | 
 |                 nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1), | 
 |                 nn.ReLU(), | 
 |                 nn.AvgPool2d(kernel_size=2, stride=2), | 
 |                 nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1), | 
 |                 nn.ReLU(), | 
 |                 nn.AdaptiveAvgPool2d((1, 1)), | 
 |                 nn.Flatten(start_dim=1, end_dim=-1), | 
 |                 nn.Linear(128, num_classes, bias=True), | 
 |             ) | 
 |         atol, rtol = self._compute_tolerances(device) | 
 |         return self._test_conv_model(convnet, 28, 2, device, loss_reduction="mean", atol=atol, rtol=rtol) | 
 |  | 
 |     @parametrize('num_dim', [1, 2, 3]) | 
 |     @tf32_off() | 
 |     def test_instance_norm_model(self, num_dim, device): | 
 |         def instance_norm_model(num_classes, num_dim): | 
 |             conv_layer = nn.Conv1d if num_dim == 1 else nn.Conv2d if num_dim == 2 else nn.Conv3d | 
 |             norm_layer = nn.InstanceNorm1d if num_dim == 1 else nn.InstanceNorm2d if num_dim == 2 else nn.InstanceNorm3d | 
 |             return nn.Sequential( | 
 |                 conv_layer(3, 32, kernel_size=3, stride=1, padding=1), | 
 |                 norm_layer(32, affine=True), | 
 |                 nn.Flatten(start_dim=1, end_dim=-1), | 
 |                 nn.Linear(32 * (7 ** num_dim), num_classes, bias=True), | 
 |             ) | 
 |         atol, rtol = self._compute_tolerances(device) | 
 |         return self._test_conv_model(instance_norm_model, 7, num_dim, device, atol=atol, rtol=rtol) | 
 |  | 
 |     @parametrize('num_dim', [1, 2, 3]) | 
 |     @tf32_off() | 
 |     def test_group_norm_model(self, num_dim, device): | 
 |         def group_norm_model(num_classes, num_dim): | 
 |             conv_layer = nn.Conv1d if num_dim == 1 else nn.Conv2d if num_dim == 2 else nn.Conv3d | 
 |             return nn.Sequential( | 
 |                 conv_layer(3, 32, kernel_size=3, stride=1, padding=1), | 
 |                 nn.GroupNorm(8, 32, affine=True), | 
 |                 nn.Flatten(start_dim=1, end_dim=-1), | 
 |                 nn.Linear(32 * (7 ** num_dim), num_classes, bias=True), | 
 |             ) | 
 |         atol, rtol = self._compute_tolerances(device) | 
 |         return self._test_conv_model(group_norm_model, 7, num_dim, device, atol=atol, rtol=rtol) | 
 |  | 
 |     @parametrize('num_dim', [1, 2, 3]) | 
 |     @tf32_off() | 
 |     def test_layer_norm_model(self, num_dim, device): | 
 |         def layer_norm_model(num_classes, num_dim): | 
 |             conv_layer = nn.Conv1d if num_dim == 1 else nn.Conv2d if num_dim == 2 else nn.Conv3d | 
 |             normalized_shape = [7] * num_dim | 
 |             return nn.Sequential( | 
 |                 conv_layer(3, 32, kernel_size=3, stride=1, padding=1), | 
 |                 nn.LayerNorm(normalized_shape, elementwise_affine=True), | 
 |                 nn.Flatten(start_dim=1, end_dim=-1), | 
 |                 nn.Linear(32 * (7 ** num_dim), num_classes, bias=True), | 
 |             ) | 
 |         atol, rtol = self._compute_tolerances(device) | 
 |         return self._test_conv_model(layer_norm_model, 7, num_dim, device, atol=atol, rtol=rtol) | 
 |  | 
 |     def test_embedding_model(self, device): | 
 |         def embedding_model(num_classes, num_embedding): | 
 |             return nn.Sequential( | 
 |                 nn.Embedding(num_embedding, 15), | 
 |                 nn.Flatten(start_dim=1, end_dim=-1), | 
 |                 nn.Linear(375, num_classes, bias=True) | 
 |             ) | 
 |         return self._test_embedding_model(embedding_model, 16, device) | 
 |  | 
 |     def test_group_norm_error(self, device): | 
 |         # group norm has to call native_group_norm. This checks that it hits the same errors | 
 |         # that normal group norm would | 
 |  | 
 |         N = 3 | 
 |         C = 5 | 
 |         inp = torch.randn(N, C) | 
 |         with self.assertRaisesRegex(RuntimeError, r"Expected number of channels in input to be divisible"): | 
 |             F.group_norm(inp, 2)  # 5 is not divisible by 2 | 
 |  | 
 | class TestExpandedWeightModule(TestCase): | 
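    # These tests exercise the public per-sample-grad entry point, roughly:
    #   out = call_for_per_sample_grads(module, loss_reduction="sum")(input)
    #   out.sum().backward()  # populates .grad_sample on each parameter
    # and compare the resulting grad_samples against a per-sample for-loop baseline.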
 |     def _do_test(self, module, input, args=None, kwargs=None, batch_first=True, atol=None, rtol=None): | 
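        # Compares grad_sample from call_for_per_sample_grads against per-sample grads computed
        # by running the module on one batch slice at a time.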
 |         args = args or () | 
 |         kwargs = kwargs or {} | 
 |  | 
 |         batch_dim = 0 if batch_first else 1 | 
 |         batch_size = input.shape[batch_dim] | 
 |         diff_input = input.dtype == torch.float or input.dtype == torch.double | 
 |         if diff_input: | 
 |             input.requires_grad_() | 
 |  | 
 |         with freeze_rng_state(): | 
 |             # get per sample grads with ExpandedWeights context manager | 
 |             actual_res = call_for_per_sample_grads(module, | 
 |                                                    batch_size=batch_size, | 
 |                                                    loss_reduction="sum", | 
 |                                                    batch_first=batch_first)(input, *args, **kwargs).sum() | 
 |             actual_res.backward() | 
 |             actual_grads = [] | 
 |             for param in module.parameters(): | 
 |                 actual_grads.append(param.grad_sample) | 
 |                 del param.grad_sample | 
 |             if diff_input: | 
 |                 actual_grads.append(input.grad.clone()) | 
 |                 input.grad = torch.zeros_like(input.grad) | 
 |  | 
 |             # get per sample grads with a for loop | 
 |             expected_res = torch.tensor(0., device=input.device, dtype=actual_res.dtype) | 
 |             expected_grads = [] | 
 |             for i in range(batch_size): | 
 |                 input_slice = input.narrow(batch_dim, i, 1) | 
 |                 input_slice = input_slice.squeeze(batch_dim) | 
 |  | 
                # h's batch dim is always dim 1, even with batch_first. Must be contiguous for CUDA
 |                 sliced_args = tree_map_only(torch.Tensor, lambda t: t.narrow(1, i, 1).contiguous(), args) | 
 |                 diff_params = module.parameters() | 
 |                 if diff_input: | 
 |                     diff_params = chain(diff_params, (input_slice,)) | 
 |                 res = module(input_slice.unsqueeze(batch_dim).contiguous(), *sliced_args, **kwargs).sum() | 
 |                 out_grads = torch.autograd.grad(res, diff_params, torch.ones_like(res), allow_unused=True) | 
 |                 expected_grads.append(out_grads) | 
 |                 expected_res += res | 
 |             expected_grads = [torch.stack(grad) for grad in zip(*expected_grads)] | 
 |             if not batch_first: | 
 |                 expected_grads[-1] = expected_grads[-1].transpose(0, 1) | 
 |         self.assertEqual(actual_res, expected_res) | 
        for (actual, expected) in zip(actual_grads, expected_grads):
            self.assertEqual(actual, expected, atol=atol, rtol=rtol)
 |  | 
 |     def _do_test_multi_input(self, module, input): | 
 |         class TestModule(nn.Module): | 
 |             def __init__(self, module): | 
 |                 super().__init__() | 
 |                 self.module = module | 
 |  | 
 |             def forward(self, input): | 
 |                 return self.module(input) + self.module(input) | 
 |  | 
 |         batch_size = input.shape[0] | 
 |         diff_input = input.dtype == torch.float or input.dtype == torch.double | 
 |         if diff_input: | 
 |             input.requires_grad_() | 
 |         with freeze_rng_state(): | 
 |             # get per sample grads with ExpandedWeights context manager, calling .backward() twice | 
 |             test_module = TestModule(module) | 
 |             actual_res = call_for_per_sample_grads(test_module, loss_reduction="sum")(input).sum() | 
 |             actual_res.backward() | 
 |             actual_grads = [] | 
 |             for param in module.parameters(): | 
 |                 actual_grads.append(param.grad_sample) | 
 |                 del param.grad_sample | 
 |             if diff_input: | 
 |                 actual_grads.append(input.grad.clone()) | 
 |                 input.grad = torch.zeros_like(input.grad) | 
 |  | 
 |  | 
 |             # get per sample grads with a for loop, running over the input twice | 
 |             expected_grads = [] | 
 |             for i in range(batch_size): | 
 |                 input_slice = input[i] | 
 |                 diff_params = module.parameters() | 
 |                 if diff_input: | 
 |                     diff_params = chain(diff_params, (input_slice,)) | 
 |                 res = module(input_slice.unsqueeze(0)).sum() | 
 |                 out_grads = torch.autograd.grad(res, diff_params, torch.ones_like(res), allow_unused=True) | 
 |                 expected_grads.append(out_grads) | 
 |         expected_grads = tuple(torch.stack(grad) for grad in zip(*expected_grads)) | 
 |         expected_grads = tuple(expected_grad for expected_grad in expected_grads if expected_grad is not None) | 
        for (actual, expected) in zip(actual_grads, expected_grads):
            self.assertEqual(actual, 2 * expected)
 |  | 
 |     def _do_test_rnn_packed_sequence(self, module, input, args=None, kwargs=None, atol=None, rtol=None): | 
 |         args = args if args is not None else () | 
 |         kwargs = kwargs if kwargs is not None else {} | 
 |  | 
 |         batch_size = max(tuple(input.batch_sizes)).item() | 
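        # (batch_sizes is non-increasing, so its max is the number of sequences in the PackedSequence)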
 |  | 
 |         with freeze_rng_state(): | 
 |             # get per sample grads with ExpandedWeights context manager | 
 |             actual_res = call_for_per_sample_grads(module, | 
 |                                                    batch_size=batch_size, | 
 |                                                    loss_reduction="sum")(input, *args, **kwargs).data.sum() | 
 |             actual_res.backward() | 
 |             actual_grads = [] | 
 |             for param in module.parameters(): | 
 |                 self.assertEqual(param.grad_sample.shape[0], batch_size) | 
 |                 actual_grads.append(param.grad_sample) | 
 |                 del param.grad_sample | 
 |  | 
 |             input.data.grad = torch.zeros_like(input.data) | 
 |  | 
 |             # compute the per sample grads with a for loop | 
 |             expected_res = torch.zeros_like(actual_res) | 
 |             expected_grads = [] | 
 |             padded_input, seq_sizes = torch.nn.utils.rnn.pad_packed_sequence(input, batch_first=True) | 
 |             for i in range(len(seq_sizes)): | 
 |                 input_slice = padded_input[i].narrow(0, 0, seq_sizes[i]) | 
 |                 diff_params = module.parameters() | 
 |                 batch_dim = 0 if module.m.batch_first else 1 | 
 |                 res = module(input_slice.unsqueeze(batch_dim), *args, **kwargs).sum() | 
 |                 expected_res += res | 
 |                 out_grads = torch.autograd.grad(res, diff_params, torch.ones_like(res), allow_unused=True) | 
 |                 expected_grads.append(out_grads) | 
 |  | 
 |             expected_grads = [torch.stack(grad) for grad in zip(*expected_grads)] | 
 |             self.assertEqual(actual_res, expected_res) | 
            for (actual, expected) in zip(actual_grads, expected_grads):
                self.assertEqual(actual, expected, atol=atol, rtol=rtol)
 |  | 
 |     @modules(filter(lambda m_info: m_info.module_cls in (torch.nn.RNN, torch.nn.LSTM, torch.nn.GRU), module_db)) | 
 |     @tf32_off() | 
 |     def test_module(self, device, dtype, module_info, training): | 
 |         class RNNWrapper(torch.nn.Module): | 
 |             def __init__(self, m_cons, args, kwargs): | 
 |                 super().__init__() | 
 |                 self.m = m_cons(*args, **kwargs) | 
 |  | 
 |             def forward(self, *inps): | 
 |                 ret = self.m(*inps) | 
 |                 assert isinstance(ret, tuple) | 
 |                 return ret[0] | 
 |  | 
 |         def batch_hidden(h): | 
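            # repeat the unbatched hidden state along a new batch dim of size 2 at dim 1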
 |             new_h_shape = [1] * (len(h.shape) + 1) | 
 |             new_h_shape[1] = 2 | 
 |             return h.unsqueeze(1).repeat(new_h_shape) | 
 |  | 
 |  | 
 |         module_cls = module_info.module_cls | 
 |         atol, rtol = (1e-4, 1e-5) if module_cls == torch.nn.GRU and dtype == torch.float32 else (None, None) | 
 |         module_inputs = module_info.module_inputs_func(module_info, device=device, dtype=dtype, | 
 |                                                        requires_grad=True, training=training, with_packed_sequence=True) | 
 |         for module_input in module_inputs: | 
 |             if module_input.forward_input is None: | 
 |                 continue | 
 |             args, kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs | 
 |             m = RNNWrapper(module_cls, args, kwargs) | 
 |             batch_first = m.m.batch_first | 
 |             m.to(device).to(dtype) | 
 |  | 
 |             args, kwargs = module_input.forward_input.args, module_input.forward_input.kwargs | 
 |  | 
            # if the RNN tests use unbatched inputs, batch the inputs
 |             input = args[0] | 
 |             if isinstance(input, torch.Tensor) and input.dim() == 2: | 
 |                 input = input.detach() | 
 |                 new_input_shape = [1] * (len(input.shape) + 1) | 
 |                 if batch_first: | 
 |                     new_input_shape[0] = 2 | 
 |                     input = input.repeat(new_input_shape) | 
 |                 else: | 
 |                     new_input_shape[1] = 2 | 
 |                     input = input.unsqueeze(1).repeat(new_input_shape) | 
 |  | 
 |                 h = args[1] if len(args) > 1 else None | 
 |                 if h is not None: | 
 |                     h = batch_hidden(h) if isinstance(h, torch.Tensor) else tuple(batch_hidden(hx) for hx in h) | 
 |                     args = list(args) | 
 |                     args[1] = h | 
 |  | 
 |             if isinstance(input, torch.nn.utils.rnn.PackedSequence): | 
 |                 self._do_test_rnn_packed_sequence(m, input, args[1:], kwargs, atol=atol, rtol=rtol) | 
 |             else: | 
 |                 self._do_test(m, input, args[1:], kwargs, batch_first=batch_first, atol=atol, rtol=rtol) | 
 |  | 
 |     def test_per_sample_api_failing(self): | 
 |         module = nn.Linear(10, 10) | 
 |         input = torch.randn(64, 10) | 
 |         with self.assertRaisesRegex(RuntimeError, r"Module passed must be nn.Module"): | 
 |             call_for_per_sample_grads("fail")(input) | 
 |         with self.assertRaisesRegex(RuntimeError, r"Batch size passed must be None or an integer"): | 
 |             call_for_per_sample_grads(module, batch_size=6.4)(input) | 
 |         with self.assertRaisesRegex(RuntimeError, r"Batch size must be positive"): | 
 |             call_for_per_sample_grads(module, batch_size=-64)(input) | 
 |         with self.assertRaisesRegex(RuntimeError, r"incorrect for multiple calls"): | 
 |             loss = call_for_per_sample_grads(module)(input).sum() | 
 |             loss.backward()  # populate grad_sample fields | 
 |             call_for_per_sample_grads(module)(input) | 
 |  | 
 |         module = nn.Linear(10, 10)  # reset to not have grad_sample fields | 
 |         with self.assertRaisesRegex(RuntimeError, r"Expected loss_reduction argument to be sum or mean"): | 
 |             call_for_per_sample_grads(module, loss_reduction="")(input) | 
 |  | 
 |     def test_per_sample_api_compute_batch_size(self): | 
 |         class CustomModule(nn.Module): | 
 |             def __init__(self): | 
 |                 super().__init__() | 
 |                 self.linear = nn.Linear(5, 5) | 
 |  | 
 |             def forward(self, input1, input2): | 
 |                 return self.linear(input1) + self.linear(input2) | 
 |  | 
 |         module = CustomModule() | 
 |         input1 = torch.randn(4, 5) | 
 |         input2 = torch.randn(5, 5) | 
 |  | 
 |         with self.assertRaisesRegex(RuntimeError, "found at least one input with batch size 4 and one with batch size 5"): | 
 |             call_for_per_sample_grads(module)(input1, input2) | 
 |  | 
 |         input2 = torch.randn(4, 5) | 
 |         call_for_per_sample_grads(module)(input1, input2) | 
 |  | 
 |         module = CustomModule() | 
 |         call_for_per_sample_grads(module)(input1, input2=input2) | 
 |  | 
 |         module = CustomModule() | 
 |         call_for_per_sample_grads(module)(input1=input1, input2=input2) | 
 |  | 
 |     def test_per_sample_api_compute_batch_size_not_pytreeable(self): | 
 |         @dataclass | 
 |         class NonPytreeableTuple: | 
 |             elem1: torch.Tensor | 
 |             elem2: torch.Tensor | 
 |  | 
 |         class CustomModule(nn.Module): | 
 |             def __init__(self): | 
 |                 super().__init__() | 
 |                 self.linear = nn.Linear(5, 5) | 
 |  | 
 |             def forward(self, input1, input2): | 
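                # input2 is intentionally unused; only input1's fields go through the linear layer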
 |                 return self.linear(input1.elem1) + self.linear(input1.elem2) | 
 |  | 
 |         input = NonPytreeableTuple(torch.randn(4, 5), torch.randn(4, 5)) | 
 |         model = CustomModule() | 
 |         with self.assertRaisesRegex(RuntimeError, "ExpandedWeights cannot compute the batch size from the inputs"): | 
 |             call_for_per_sample_grads(model)(input, "") | 
 |  | 
        # we would prefer it to error because the input is not pytree-able, but that's hard to detect
 |         with self.assertRaisesRegex(RuntimeError, "Expected ExpandedWeights to have batch size matching input"): | 
 |             call_for_per_sample_grads(model)(input, torch.randn(5)) | 
 |  | 
 |         model = CustomModule()  # TODO: functional call bug, sam will fix | 
 |         call_for_per_sample_grads(model)(input, torch.randn(4, 5)) | 
 |         model = CustomModule() | 
 |         call_for_per_sample_grads(model, batch_size=4)(input, torch.randn(5)) | 
 |  | 
 | class ContextManagerTests(TestBase): | 
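    # Adapts the legacy common_nn test entries (module_tests/new_module_tests) into per-sample-grad
    # checks that run through TestExpandedWeightModule._do_test and _do_test_multi_input.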
 |     def __init__(self, *args, **kwargs): | 
 |         self.test_cpu = kwargs.get('test_cpu', True) | 
 |         self.test_cuda = kwargs.get('test_cuda', True) | 
 |         super().__init__(*args, **kwargs) | 
 |  | 
 |     @property | 
 |     def constructor_args(self): | 
 |         return self._get_arg('constructor_args', False) | 
 |  | 
 |     def test_context_manager(self, test_case, device): | 
 |         kwargs = {'device': device, 'dtype': torch.double} | 
 |         module = self.constructor(*self.constructor_args).to(**kwargs) | 
 |         if 'Embedding' in self.get_name(): | 
 |             kwargs['dtype'] = torch.long | 
 |         input = self._get_input().to(**kwargs) | 
 |         if len(input.shape) == 0 or input.shape[0] == 0: | 
 |             raise unittest.SkipTest("Can't get per sample gradients when no batch dim or batch dim is 0") | 
 |         if self.constructor == torch.nn.Linear and len(input.shape) == 1: | 
 |             raise unittest.SkipTest("Can't get per sample gradients for input of rank 1") | 
 |         test_case._do_test(module, input) | 
 |  | 
 |     def test_context_manager_multiple_inputs(self, test_case, device): | 
 |         module = self.constructor(*self.constructor_args).to(device) | 
 |         input = self._get_input() | 
 |         if len(input.shape) == 0 or input.shape[0] == 0: | 
 |             raise unittest.SkipTest("Can't get per sample gradients when no batch dim or batch dim is 0") | 
 |         if self.constructor == torch.nn.Linear and len(input.shape) == 1: | 
 |             raise unittest.SkipTest("Can't get per sample gradients for input of rank 1") | 
 |         test_case._do_test_multi_input(module, input) | 
 |  | 
def filter_supported_tests(t):
    supported_modules = ['Linear', 'Conv1d', 'Conv2d', 'Conv3d', 'Embedding', 'LayerNorm', 'GroupNorm', 'InstanceNorm']
    return 'module_name' in t and t['module_name'] in supported_modules
 |  | 
 | # TODO: Once all of these use ModuleInfo, replace with ModuleInfo tests | 
 | # These currently use the legacy nn tests | 
 | supported_tests = [t for t in module_tests + new_module_tests if filter_supported_tests(t)] | 
 | for test_param in supported_tests: | 
 |     if 'constructor' not in test_param: | 
 |         name = test_param.pop('module_name') | 
 |         test_param['constructor'] = getattr(nn, name) | 
 |     decorator = test_param.pop('decorator', lambda test: test) | 
 |     test = ContextManagerTests(**test_param) | 
 |     test_name = test.get_name() | 
 |     if hasattr(TestExpandedWeightModule, test_name): | 
 |         raise RuntimeError('Found two tests with the same name: ' + test_name) | 
    test_name_multi_input = test.get_name() + "_multiple_inputs"
    if hasattr(TestExpandedWeightModule, test_name_multi_input):
        raise RuntimeError('Found two tests with the same name: ' + test_name_multi_input)
 |     if test.test_cpu: | 
 |         setattr(TestExpandedWeightModule, test_name, decorator(lambda self, test=test: test.test_context_manager(self, 'cpu'))) | 
 |         setattr(TestExpandedWeightModule, test_name_multi_input, | 
 |                 decorator(lambda self, test=test: test.test_context_manager_multiple_inputs(self, 'cpu'))) | 
 |     if TEST_CUDA and test.test_cuda: | 
 |         # since this checks derivatives, only use double for precision | 
 |         setattr(TestExpandedWeightModule, test_name + '_cuda_double', | 
 |                 decorator(lambda self, test=test: test.test_context_manager(self, 'cuda'))) | 
 |  | 
 | # ------------- HELPER FUNCTIONS ----------------- | 
 |  | 
 | def run_op(op, input, *args, **kwargs): | 
 |     r""" | 
 |     OpInfo for Embedding switches the input and weight so autograd tests will only check the derivative | 
 |     of the weight, not the input, which can't be differentiable since its dtype is int. Calls op, | 
 |     using the special ordering that Embedding's OpInfo expects for that case. | 
 |     """ | 
 |     if op.name == "nn.functional.embedding": | 
 |         return op(args[0], input, **kwargs) | 
 |     else: | 
 |         return op(input, *args, **kwargs) | 
 |  | 
 | def make_expanded_weight(sample_input, batch_size, loss_reduction="sum"): | 
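    # Wraps each differentiable arg/kwarg in an ExpandedWeight with the given batch size;
    # the input and non-differentiable args are just cloned.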
 |     def expanded_weight_or_clone(arg): | 
 |         if is_diff_tensor(arg): | 
 |             return ExpandedWeight(torch.clone(arg), batch_size, loss_reduction) | 
 |         return clone_if_tensor(arg) | 
 |  | 
 |     ew_input = clone_if_tensor(sample_input.input) | 
 |     ew_args = tuple(expanded_weight_or_clone(arg) for arg in sample_input.args) | 
 |     ew_kwargs = {name: expanded_weight_or_clone(arg) for (name, arg) in sample_input.kwargs.items()} | 
 |     return ew_input, ew_args, ew_kwargs | 
 |  | 
 | def supported_inputs(op, sample_inputs, supported_inputs=True): | 
 |     r""" | 
 |     ExpandedWeights currently does not support some use cases when there's no batch dimension or | 
 |     operations that would cause inter-batch operations. Removes all of the cases it cannot deal with | 
 |     """ | 
 |     def filter_fn(input): | 
 |         convolutions = ["nn.functional.conv1d", "nn.functional.conv2d", "nn.functional.conv3d"] | 
 |         batched_input_size = dict(zip(convolutions, [3, 4, 5])) | 
 |         if op.name == "nn.functional.linear": | 
 |             is_supported_input = input.input.dim() > 1  # input of rank 1 means no batch dim | 
 |         elif op.name == "nn.functional.layer_norm": | 
 |             normalized_shape = input.args[0] | 
 |             is_supported_input = input.input.shape != normalized_shape  # would cause inter-batch operations | 
 |         elif op.name in convolutions: | 
 |             # currently can't deal with padding computation on Python level | 
 |             is_supported_input = input.input.dim() == batched_input_size[op.name] | 
 |         elif op.name == "nn.functional.embedding": | 
 |             idx = input.args[0] | 
            is_supported_input = len(idx.shape) > 1  # a 1-D index tensor has no batch dim
 |         else: | 
 |             is_supported_input = True | 
 |         is_supported_input = is_supported_input and input.input.shape[0] > 0  # 0 is not a valid batch size | 
 |         return is_supported_input if supported_inputs else not is_supported_input | 
 |     return [input for input in sample_inputs if filter_fn(input)] | 
 |  | 
 | def for_loop_per_sample_grad(batch_size, reduction, input, func, *args, **kwargs): | 
 |     # get per sample grads by getting derivative for each input in a for loop | 
 |     per_sample_grad = [] | 
 |     for i in range(batch_size): | 
 |         per_sample_input = input[i] | 
 |         result = reduction(func(per_sample_input.unsqueeze(0), *args, **kwargs)) | 
 |         diff_input_list = (per_sample_input,) + tuple(args) + tuple(kwargs.values()) | 
 |         diff_input_list = [i for i in diff_input_list if isinstance(i, torch.Tensor) and i.requires_grad] | 
 |         per_sample_grad.append(torch.autograd.grad(result, diff_input_list, torch.ones_like(result), allow_unused=True)) | 
 |     if len(per_sample_grad) == batch_size: | 
 |         per_sample_grad = tuple(torch.stack(grad) for grad in zip(*per_sample_grad)) | 
 |     return per_sample_grad | 
 |  | 
 | def is_diff_tensor(t): | 
 |     return isinstance(t, ExpandedWeight) or (isinstance(t, torch.Tensor) and t.requires_grad) | 
 |  | 
 | def clone_if_tensor(t): | 
 |     if isinstance(t, torch.Tensor): | 
 |         res = torch.clone(t).detach() | 
 |         res.requires_grad_(t.requires_grad) | 
 |         return res | 
 |     else: | 
 |         return t | 
 |  | 
 | instantiate_device_type_tests(TestExpandedWeightHelperFunction, globals()) | 
 | instantiate_device_type_tests(TestExpandedWeightFunctional, globals()) | 
 | instantiate_device_type_tests(TestExpandedWeightModule, globals()) | 
 | if __name__ == '__main__': | 
 |     run_tests() |