Change `test_conv_large` parameter initialization (#71521)

Summary:
This PR twiddles the parameters of the conv layer in `test_conv_large` to better avoid NaN values. Previously, this test would cause a NaN to be computed for `scale` (propagated from `.mean()` on the `.grad` tensor). This NaN would then be propagated to the scaled gradients via division, resulting in a bogus `assertEqual` check as `NaN == NaN` is by default true. (This behavior was observed on V100 and A100).
To improve visibility of failures in the event of NaNs in `grad1`, scale is now computed from `grad2`.

Interestingly enough, we discovered this issue when trying out some less common setups that broke this test; it turns out those breakages were cases where there were no NaN values (leading to an actual `assertEqual` check that would fail for `float16`).

CC ptrblck ngimel puririshi98

Pull Request resolved: https://github.com/pytorch/pytorch/pull/71521

Reviewed By: anjali411

Differential Revision: D33776705

Pulled By: ngimel

fbshipit-source-id: a1ec4792cba04c6322b22ef5b80ce08579ea4cf6
(cherry picked from commit d207bd9b87f8e8c2cb13182b7295c17e19dc3dba)
diff --git a/test/test_nn.py b/test/test_nn.py
index 8c8fe85..987c585 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -16336,6 +16336,7 @@
     def test_conv_large(self, device):
         dtype = torch.half if self.device_type == 'cuda' else torch.float
         conv = nn.Conv2d(2, 2, 8, 8, bias=False).to(device).to(dtype)
+        conv.weight = torch.nn.Parameter(torch.randn(2, 2, 8, 8, device=device, dtype=dtype) / 64)
         input_large = torch.randn(4097, 2, 512, 512, dtype=dtype, device=device)
         # forward
         ret = conv(input_large)
@@ -16358,10 +16359,10 @@
         grad2 = conv.weight.grad.detach().clone()
         # gradients are at the order of hundreds, we need to scale it to
         # the order of one so that we can compare
-        scale = 1 / grad1.abs().mean()
+        scale = 1 / grad2.abs().mean()
         grad1 = grad1 * scale
         grad2 = grad2 * scale
-        self.assertEqual(grad1, grad2)
+        self.assertEqual(grad1, grad2, atol=5e-2, rtol=5e-3)
 
     def _test_gumbel_softmax_st_shapes(self, device, dtype, shape, dim, count_expected):
         logits = torch.randn(shape, dtype=torch.float, device=device)