Fix assertNotEqual error reporting (#39217)
Summary:
`msg` argument must be passed to `assertRaises`, because its exception is passed upstream (with custom error message) if `assertEquals` succeedes.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/39217
Differential Revision: D21786141
Pulled By: malfet
fbshipit-source-id: f8c3d4f30f474fe269e50252a06eade76d575a68
diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index 4b1fcc4..8a28dd7 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -1109,8 +1109,8 @@
self.assertEqual(x, y, msg=msg, atol=prec, rtol=rtol)
def assertNotEqual(self, x, y, *, msg=None, atol=None, rtol=None):
- with self.assertRaises(AssertionError):
- self.assertEqual(x, y, msg=msg, atol=atol, rtol=rtol)
+ with self.assertRaises(AssertionError, msg=msg):
+ self.assertEqual(x, y, atol=atol, rtol=rtol)
def assertEqualTypeString(self, x, y):
# This API is used simulate deprecated x.type() == y.type()