Enable norm gradgradchecks by lowering precision requirements.
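
The generated function tests previously skipped gradgradcheck entirely for Norm via gradgradcheck_exclude_classes. This change drops that exclusion and instead relaxes atol/rtol for the few Norm test cases whose second derivatives sit close to the precision limit of the numerical check, so every generated test now runs gradgradcheck.

The sketch below illustrates the lookup pattern the patch uses. It is not part of the patch; run_precision_aware_gradgradcheck is a hypothetical stand-in for the surrounding generated-test code in test_autograd.py:

    from torch.autograd import gradgradcheck

    # Per-test tolerance overrides, keyed by generated test name (empirical values).
    gradgradcheck_precision_override = {
        'test_NormFunction_1_5': {'atol': 1e-2, 'rtol': 1e-2},
    }

    def run_precision_aware_gradgradcheck(test_name, apply_fn, inputs, grad_y):
        # Fall back to gradgradcheck's default atol/rtol when no override exists.
        kwargs = gradgradcheck_precision_override.get(test_name, {})
        return gradgradcheck(apply_fn, inputs, grad_y, **kwargs)
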
diff --git a/test/test_autograd.py b/test/test_autograd.py
index b288e10..f8b4cdc 100644
--- a/test/test_autograd.py
+++ b/test/test_autograd.py
@@ -1911,9 +1911,12 @@
 ))
-gradgradcheck_exclude_classes = set((
-    'Norm',
-))
+# these are just empirical observations, we should improve
+gradgradcheck_precision_override = {
+    'test_NormFunction_1_5': {'atol': 1e-2, 'rtol': 1e-2},
+    'test_NormFunction_2': {'atol': 1e-2, 'rtol': 1e-2},
+    'test_NormFunction_3': {'atol': 5e-2, 'rtol': 1e-2},
+}
 for test in function_tests:
     cls, constructor_args, call_args = test[:3]
@@ -1968,14 +1971,18 @@
                 self.assertTrue(type(inp.data) == type(inp.grad.data))
                 self.assertTrue(inp.size() == inp.grad.size())
-            if cls.__name__ not in gradgradcheck_exclude_classes:
-                dummy_out = apply_fn(*input)
-                if isinstance(dummy_out, tuple):
-                    grad_y = tuple(Variable(torch.randn(x.size()), requires_grad=x.requires_grad)
-                                   for x in dummy_out if isinstance(x, Variable))
-                else:
-                    grad_y = (Variable(torch.randn(dummy_out.size()), requires_grad=dummy_out.requires_grad),)
+            dummy_out = apply_fn(*input)
+            if isinstance(dummy_out, tuple):
+                grad_y = tuple(Variable(torch.randn(x.size()), requires_grad=x.requires_grad)
+                               for x in dummy_out if isinstance(x, Variable))
+            else:
+                grad_y = (Variable(torch.randn(dummy_out.size()), requires_grad=dummy_out.requires_grad),)
+            if test_name in gradgradcheck_precision_override:
+                atol = gradgradcheck_precision_override[test_name]['atol']
+                rtol = gradgradcheck_precision_override[test_name]['rtol']
+                self.assertTrue(gradgradcheck(apply_fn, input, grad_y, atol=atol, rtol=rtol))
+            else:
                 self.assertTrue(gradgradcheck(apply_fn, input, grad_y,))
             # can't broadcast inplace to left hand side
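
For reference, a rough standalone check of Norm's double backward with the relaxed tolerances. This is a sketch using the present-day tensor API rather than the Variable wrappers above, and the tolerances that actually suffice can vary with the random input:

    import torch
    from torch.autograd import gradgradcheck

    torch.manual_seed(0)
    # Double-precision input; the second derivative of the p-norm is numerically
    # touchy (e.g. it involves |x|**(p - 2) terms, singular near zero for p=1.5),
    # which is why the Norm test cases above get looser atol/rtol.
    x = torch.randn(4, 6, dtype=torch.double, requires_grad=True)
    print(gradgradcheck(lambda t: t.norm(1.5), (x,), atol=1e-2, rtol=1e-2))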