Change `test_conv_large` parameter initialization (#71521)
Summary:
This PR adjusts the parameter initialization of the conv layer in `test_conv_large` to better avoid NaN values. Previously, this test could compute a NaN for `scale` (propagated from `.mean()` on the `.grad` tensor). That NaN would then propagate to the scaled gradients via division, producing a bogus `assertEqual` check, since the comparison treats `NaN == NaN` as equal by default. (This behavior was observed on V100 and A100.)
To improve visibility of failures in the event of NaNs in `grad1`, `scale` is now computed from `grad2` instead.
Interestingly enough, we discovered this issue when trying some less common setups that broke this test; it turned out those breakages were cases where no NaN values appeared, leading to an actual `assertEqual` comparison that would fail for `float16`.
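To illustrate the failure mode described above, here is a minimal sketch (using plain Python floats rather than tensors, and assuming a NaN-tolerant equality check like the test suite's `assertEqual`) of how a NaN in the mean poisons both sides of the comparison:

```python
import math

# Hypothetical stand-ins for grad1.abs().mean() and the two gradient values.
grad1_mean = float("nan")      # .mean() over a grad tensor containing a NaN
scale = 1 / grad1_mean         # NaN propagates through the division
scaled_grad1 = 123.0 * scale   # every scaled gradient becomes NaN...
scaled_grad2 = 456.0 * scale   # ...on both sides of the comparison

assert math.isnan(scaled_grad1) and math.isnan(scaled_grad2)
# A comparison that treats NaN == NaN as equal would now "pass"
# even though the underlying values (123.0 vs 456.0) clearly differ.
```

Computing `scale` from `grad2` instead means a NaN appearing only in `grad1` no longer poisons both operands, so the mismatch surfaces as a real failure.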
CC ptrblck ngimel puririshi98
Pull Request resolved: https://github.com/pytorch/pytorch/pull/71521
Reviewed By: anjali411
Differential Revision: D33776705
Pulled By: ngimel
fbshipit-source-id: a1ec4792cba04c6322b22ef5b80ce08579ea4cf6
(cherry picked from commit d207bd9b87f8e8c2cb13182b7295c17e19dc3dba)
diff --git a/test/test_nn.py b/test/test_nn.py
index 8c8fe85..987c585 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -16336,6 +16336,7 @@
def test_conv_large(self, device):
dtype = torch.half if self.device_type == 'cuda' else torch.float
conv = nn.Conv2d(2, 2, 8, 8, bias=False).to(device).to(dtype)
+ conv.weight = torch.nn.Parameter(torch.randn(2, 2, 8, 8, device=device, dtype=dtype) / 64)
input_large = torch.randn(4097, 2, 512, 512, dtype=dtype, device=device)
# forward
ret = conv(input_large)
@@ -16358,10 +16359,10 @@
grad2 = conv.weight.grad.detach().clone()
# gradients are at the order of hundreds, we need to scale it to
# the order of one so that we can compare
- scale = 1 / grad1.abs().mean()
+ scale = 1 / grad2.abs().mean()
grad1 = grad1 * scale
grad2 = grad2 * scale
- self.assertEqual(grad1, grad2)
+ self.assertEqual(grad1, grad2, atol=5e-2, rtol=5e-3)
def _test_gumbel_softmax_st_shapes(self, device, dtype, shape, dim, count_expected):
logits = torch.randn(shape, dtype=torch.float, device=device)