Use unittest assertWarns instead (#36411)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/36411
This PR remove pytorch specific defined assertwarns and use the unit
test one, also format some tests
Test Plan: Imported from OSS
Differential Revision: D20998159
Pulled By: wanchaol
fbshipit-source-id: 1280ecff2dd293b95a639d13cc7417fc819c2201
diff --git a/test/distributed/test_data_parallel.py b/test/distributed/test_data_parallel.py
index 5d7f2c5..cdd3fcf 100644
--- a/test/distributed/test_data_parallel.py
+++ b/test/distributed/test_data_parallel.py
@@ -644,9 +644,10 @@
self._testcase = testcase
def forward(self, x):
- self._testcase.assertWarnsRegex(
- lambda: self.zero_grad(),
- r"Calling \.zero_grad\(\) from a module that was passed to a nn\.DataParallel\(\) has no effect.")
+ with self._testcase.assertWarnsRegex(
+ UserWarning,
+ r"Calling \.zero_grad\(\) from a module that was passed to a nn\.DataParallel\(\) has no effect."):
+ self.zero_grad()
return x
module = Net(self).cuda()
diff --git a/test/test_dataloader.py b/test/test_dataloader.py
index 2918411..c18f12d 100644
--- a/test/test_dataloader.py
+++ b/test/test_dataloader.py
@@ -320,6 +320,8 @@
set_faulthander_if_available()
# Process `pid` must have called `set_faulthander_if_available`
+
+
def print_traces_of_all_threads(pid):
if HAS_FAULTHANDLER:
if not IS_WINDOWS:
@@ -902,7 +904,6 @@
with self.assertRaisesRegex(ValueError, "timeout option should be non-negative"):
DataLoader(self.dataset, timeout=-1)
-
# disable auto-batching
with self.assertRaisesRegex(ValueError,
"batch_size=None option disables auto-batching and is mutually exclusive"):
@@ -1024,10 +1025,11 @@
for _ in range(20):
self.assertNotWarn(lambda: next(it), "Should not warn before exceeding length")
for _ in range(3):
- self.assertWarnsRegex(
- lambda: next(it),
+ with self.assertWarnsRegex(
+ UserWarning,
r"but [0-9]+ samples have been fetched\. For multiprocessing data-loading, this",
- "Should always warn after exceeding length")
+ msg="Should always warn after exceeding length"):
+ next(it)
# [no auto-batching] test that workers exit gracefully
workers = dataloader_iter._workers
diff --git a/test/test_jit.py b/test/test_jit.py
index 0aeebc1..1486049 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -406,7 +406,7 @@
def __init__(self, cpu_device_str):
super(M, self).__init__()
self.p0 = nn.Parameter(torch.tensor([0.3], dtype=torch.float,
- device=cpu_device_str))
+ device=cpu_device_str))
self.b0 = torch.tensor([0.9], dtype=torch.float,
device=cpu_device_str)
@@ -14263,12 +14263,6 @@
torch.testing.assert_allclose(y_int8, y_ref, rtol=0.0001, atol=1e-3)
torch.testing.assert_allclose(y_fp16, y_ref, rtol=0.0001, atol=1e-3)
- def checkTracerWarning(self, *args, **kwargs):
- with warnings.catch_warnings(record=True) as warns:
- torch.jit.trace(*args, **kwargs)
- self.assertGreater(len(warns), 0)
- self.assertTrue(any(["cause the trace to be incorrect" in str(warn.message) for warn in warns]))
-
def test_trace_checker_slice_lhs(self):
def foo(x):
for i in range(3):
@@ -14282,18 +14276,21 @@
x.view(-1).add_(-x.view(-1))
return x
- self.assertWarnsRegex(lambda: torch.jit.trace(foo,
- torch.rand(3, 4),
- check_inputs=[torch.rand(5, 6)],
- _force_outplace=True),
- 'Output nr 1. of the traced function does not match the '
- 'corresponding output of the Python function')
+ with self.assertWarnsRegex(torch.jit.TracerWarning,
+ 'Output nr 1. of the traced function does not match the '
+ 'corresponding output of the Python function'):
+ torch.jit.trace(foo,
+ torch.rand(3, 4),
+ check_inputs=[torch.rand(5, 6)],
+ _force_outplace=True)
def test_lhs_index_fails(self):
def foo(x):
x[0, 1] = 4
return x
- self.checkTracerWarning(foo, torch.rand(3, 4), _force_outplace=True)
+
+ with self.assertWarnsRegex(torch.jit.TracerWarning, "cause the trace to be incorrect"):
+ torch.jit.trace(foo, torch.rand(3, 4), _force_outplace=True)
def test_lhs_index_trivial(self):
def foo(y, x):
@@ -14305,18 +14302,23 @@
def foo(x):
x.view(-1).add_(-x.view(-1))
return x
- self.checkTracerWarning(foo, torch.rand(3, 4), _force_outplace=True)
+
+ with self.assertWarnsRegex(torch.jit.TracerWarning, "cause the trace to be incorrect"):
+ torch.jit.trace(foo, torch.rand(3, 4), _force_outplace=True)
@suppress_warnings
def test_trace_checker_dropout_train(self):
def foo(x):
return torch.dropout(x, p=0.5, train=True)
- self.assertWarnsRegex(lambda: torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[torch.rand(5, 6)]),
- 'Output nr 1. of the traced function does not match the '
- 'corresponding output of the Python function')
- self.assertWarnsRegex(lambda: torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[torch.rand(5, 6)]),
- 'Trace had nondeterministic nodes')
+ with self.assertWarnsRegex(torch.jit.TracerWarning,
+ 'Output nr 1. of the traced function does not match the '
+ 'corresponding output of the Python function'):
+ torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[torch.rand(5, 6)])
+
+ with self.assertWarnsRegex(torch.jit.TracerWarning,
+ 'Trace had nondeterministic nodes'):
+ torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[torch.rand(5, 6)])
def test_trace_checker_dropout_notrain(self):
input = torch.rand(3, 4)
diff --git a/test/test_mkldnn.py b/test/test_mkldnn.py
index 2d21d1c..93ddd1f 100644
--- a/test/test_mkldnn.py
+++ b/test/test_mkldnn.py
@@ -67,10 +67,12 @@
# because MKLDNN only supports float32, we need to lessen the precision.
# these numbers are just empirical results that seem to work.
- self.assertWarnsRegex(lambda: gradcheck(func, [root], atol=4e-2, rtol=1e-2),
- 'double precision floating point')
- self.assertWarnsRegex(lambda: gradgradcheck(func, [root], atol=4e-2, rtol=1e-2),
- 'double precision floating point')
+ self.assertWarnsRegex(UserWarning,
+ 'double precision floating point',
+ lambda: gradcheck(func, [root], atol=4e-2, rtol=1e-2))
+ self.assertWarnsRegex(UserWarning,
+ 'double precision floating point',
+ lambda: gradgradcheck(func, [root], atol=4e-2, rtol=1e-2))
def test_autograd_from_mkldnn(self):
# MKLDNN only supports float32
@@ -81,8 +83,9 @@
# because MKLDNN only supports float32, we need to lessen the precision.
# these numbers are just empirical results that seem to work.
- self.assertWarnsRegex(lambda: gradcheck(func, [root], atol=4e-2, rtol=1e-2),
- 'double precision floating point')
+ self.assertWarnsRegex(UserWarning,
+ 'double precision floating point',
+ lambda: gradcheck(func, [root], atol=4e-2, rtol=1e-2))
def test_detach(self):
root = torch.randn(4, 5, dtype=torch.float32).to_mkldnn().requires_grad_()
diff --git a/test/test_nn.py b/test/test_nn.py
index 6aaa480..1bf4154 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -8368,7 +8368,9 @@
def fn():
init.normal(x)
- self.assertWarnsRegex(fn, 'deprecated', 'methods not suffixed with underscore should be deprecated')
+
+ with self.assertWarnsRegex(UserWarning, 'deprecated', msg='methods not suffixed with underscore should be deprecated'):
+ fn()
class TestFusionEval(TestCase):
@given(X=hu.tensor(shapes=((5, 3, 5, 5),)),
diff --git a/test/test_optim.py b/test/test_optim.py
index 22eeaf7..b0d502b 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -575,7 +575,7 @@
scheduler.step()
self.opt.step()
- self.assertWarnsRegex(old_pattern, r'how-to-adjust-learning-rate')
+ self.assertWarnsRegex(UserWarning, r'how-to-adjust-learning-rate', old_pattern)
def test_old_pattern_warning_with_arg(self):
epochs = 35
@@ -589,7 +589,7 @@
scheduler.step()
self.opt.step()
- self.assertWarnsRegex(old_pattern2, r'how-to-adjust-learning-rate')
+ self.assertWarnsRegex(UserWarning, r'how-to-adjust-learning-rate', old_pattern2)
def test_old_pattern_warning_resuming(self):
epochs = 35
@@ -606,7 +606,7 @@
scheduler.step()
self.opt.step()
- self.assertWarnsRegex(old_pattern, r'how-to-adjust-learning-rate')
+ self.assertWarnsRegex(UserWarning, r'how-to-adjust-learning-rate', old_pattern)
def test_old_pattern_warning_resuming_with_arg(self):
epochs = 35
@@ -623,7 +623,7 @@
scheduler.step()
self.opt.step()
- self.assertWarnsRegex(old_pattern2, r'how-to-adjust-learning-rate')
+ self.assertWarnsRegex(UserWarning, r'how-to-adjust-learning-rate', old_pattern2)
def test_old_pattern_warning_with_overridden_optim_step(self):
epochs = 35
@@ -651,7 +651,7 @@
scheduler.step()
self.opt.step()
- self.assertWarnsRegex(old_pattern2, r'how-to-adjust-learning-rate')
+ self.assertWarnsRegex(UserWarning, r'how-to-adjust-learning-rate', old_pattern2)
def test_new_pattern_no_warning(self):
epochs = 35
@@ -704,7 +704,7 @@
self.opt.step()
scheduler.step()
- self.assertWarnsRegex(new_pattern, r'`optimizer.step\(\)` has been overridden')
+ self.assertWarnsRegex(UserWarning, r'`optimizer.step\(\)` has been overridden', new_pattern)
def _test_lr_is_constant_for_constant_epoch(self, scheduler):
l = []
diff --git a/test/test_torch.py b/test/test_torch.py
index cb35f75..3a1b4e7 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -527,7 +527,7 @@
self.assertEqual(n.shape, t.shape)
if t.dtype == torch.float:
self.assertTrue(np.allclose(n, t.numpy(), rtol=1e-03, atol=1e-05,
- equal_nan=True))
+ equal_nan=True))
else:
self.assertTrue(np.allclose(n, t.numpy(), equal_nan=True))
@@ -549,29 +549,29 @@
dim).cpu(),
expected)
do_one(self._make_tensors((5, 400000), use_floating=use_floating,
- use_integral=use_integral), 1)
+ use_integral=use_integral), 1)
do_one(self._make_tensors((3, 5, 7), use_floating=use_floating,
- use_integral=use_integral), 0)
+ use_integral=use_integral), 0)
do_one(self._make_tensors((3, 5, 7), use_floating=use_floating,
- use_integral=use_integral), 1)
+ use_integral=use_integral), 1)
do_one(self._make_tensors((3, 5, 7), use_floating=use_floating,
- use_integral=use_integral), 2)
+ use_integral=use_integral), 2)
do_one(self._make_tensors((100000, ), use_floating=use_floating,
- use_integral=use_integral), -1)
+ use_integral=use_integral), -1)
do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
- use_integral=use_integral), 0)
+ use_integral=use_integral), 0)
do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
- use_integral=use_integral), 1)
+ use_integral=use_integral), 1)
do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
- use_integral=use_integral), 2)
+ use_integral=use_integral), 2)
do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
- use_integral=use_integral), (1, 2))
+ use_integral=use_integral), (1, 2))
do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
- use_integral=use_integral), (1, -1))
+ use_integral=use_integral), (1, -1))
do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
- use_integral=use_integral), (0, 2))
+ use_integral=use_integral), (0, 2))
do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
- use_integral=use_integral), (0, 2, 1))
+ use_integral=use_integral), (0, 2, 1))
@slowTest
@unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
@@ -1454,9 +1454,8 @@
# Test Rounding Errors
line = torch.zeros(size=(1, 49))
- self.assertWarnsRegex(lambda: torch.arange(-1, 1, 2. / 49, dtype=torch.float32, out=line),
- 'resized',
- 'The out tensor will be resized')
+ self.assertWarnsRegex(UserWarning, 'The out tensor will be resized',
+ lambda: torch.arange(-1, 1, 2. / 49, dtype=torch.float32, out=line))
self.assertEqual(line.shape, [50])
x = torch.empty(1).expand(10)
@@ -2274,7 +2273,7 @@
def test_numpy_non_writeable(self):
arr = np.zeros(5)
arr.flags['WRITEABLE'] = False
- self.assertWarns(lambda: torch.from_numpy(arr))
+ self.assertWarns(UserWarning, lambda: torch.from_numpy(arr))
@unittest.skipIf(not TEST_NUMPY, "Numpy not found")
def test_empty_storage_view(self):
@@ -5165,9 +5164,8 @@
# inputs are non-expandable tensors, but they have same number of elements
# TORCH_WARN_ONCE is used in torch.normal, only 1st assertEqual will show warn msg
if not warned:
- self.assertWarnsRegex(
- lambda: self.assertEqual(torch.normal(tensor120, tensor2345).size(), (120,)),
- "deprecated and the support will be removed")
+ self.assertWarnsRegex(UserWarning, "deprecated and the support will be removed",
+ lambda: self.assertEqual(torch.normal(tensor120, tensor2345).size(), (120,)))
warned = True
else:
self.assertEqual(torch.normal(tensor120, tensor2345).size(), (120,))
@@ -7836,19 +7834,19 @@
[1, 2]])
columns = [0],
self.assertEqual(reference[rows, columns], torch.tensor([[1, 1],
- [3, 5]], dtype=dtype, device=device))
+ [3, 5]], dtype=dtype, device=device))
rows = ri([[0, 0],
[1, 2]])
columns = ri([1, 0])
self.assertEqual(reference[rows, columns], torch.tensor([[2, 1],
- [4, 5]], dtype=dtype, device=device))
+ [4, 5]], dtype=dtype, device=device))
rows = ri([[0, 0],
[1, 2]])
columns = ri([[0, 1],
[1, 0]])
self.assertEqual(reference[rows, columns], torch.tensor([[1, 2],
- [4, 5]], dtype=dtype, device=device))
+ [4, 5]], dtype=dtype, device=device))
# setting values
reference[ri([0]), ri([1])] = -1
@@ -10393,8 +10391,8 @@
self.assertEqual(res1, res2)
a = torch.tensor([[True, False, True],
- [False, False, False],
- [True, True, True]], dtype=torch.bool, device=device)
+ [False, False, False],
+ [True, True, True]], dtype=torch.bool, device=device)
b = a.byte()
aRes = torch.cumprod(a, 0)
bRes = torch.cumprod(b, 0)
@@ -10444,8 +10442,8 @@
self.assertEqual(out1[1], indices2)
a = torch.tensor([[True, False, True],
- [False, False, False],
- [True, True, True]], dtype=torch.bool, device=device)
+ [False, False, False],
+ [True, True, True]], dtype=torch.bool, device=device)
b = a.byte()
aRes = op(a, 0)
bRes = op(b, 0)
@@ -12674,7 +12672,7 @@
self.assertEqual(sort_topk, a[topk[1]]) # check indices
@dtypesIfCUDA(*([torch.half, torch.float, torch.double]
- + ([torch.bfloat16] if TEST_WITH_ROCM else [])))
+ + ([torch.bfloat16] if TEST_WITH_ROCM else [])))
@dtypes(torch.float, torch.double)
def test_topk_nonfinite(self, device, dtype):
x = torch.tensor([float('nan'), float('inf'), 1e4, 0, -1e4, -float('inf')], device=device, dtype=dtype)
@@ -13073,7 +13071,7 @@
self.assertEqual(input_values.erf().erfinv(), input_values)
# test inf
self.assertTrue(torch.equal(torch.tensor([-1, 1], dtype=dtype, device=device).erfinv(),
- torch.tensor([-inf, inf], dtype=dtype, device=device)))
+ torch.tensor([-inf, inf], dtype=dtype, device=device)))
# test nan
self.assertEqual(torch.tensor([-2, 2], dtype=dtype, device=device).erfinv(),
torch.tensor([nan, nan], dtype=dtype, device=device))
@@ -14976,15 +14974,11 @@
if to_ > from_:
if not (min_val <= from_ <= max_val) or not (min_val <= (to_ - 1) <= max_val):
if not (min_val <= from_ <= max_val):
- self.assertWarnsRegex(
- lambda: t.random_(from_, to_),
- "from is out of bounds"
- )
+ self.assertWarnsRegex(UserWarning, "from is out of bounds",
+ lambda: t.random_(from_, to_))
if not (min_val <= (to_ - 1) <= max_val):
- self.assertWarnsRegex(
- lambda: t.random_(from_, to_),
- "to - 1 is out of bounds"
- )
+ self.assertWarnsRegex(UserWarning, "to - 1 is out of bounds",
+ lambda: t.random_(from_, to_))
else:
t.random_(from_, to_)
range_ = to_ - from_
@@ -15078,15 +15072,11 @@
if to_ > from_:
if not (min_val <= from_ <= max_val) or not (min_val <= (to_ - 1) <= max_val):
if not (min_val <= from_ <= max_val):
- self.assertWarnsRegex(
- lambda: t.random_(from_, to_),
- "from is out of bounds"
- )
+ self.assertWarnsRegex(UserWarning, "from is out of bounds",
+ lambda: t.random_(from_, to_))
if not (min_val <= (to_ - 1) <= max_val):
- self.assertWarnsRegex(
- lambda: t.random_(from_, to_),
- "to - 1 is out of bounds"
- )
+ self.assertWarnsRegex(UserWarning, "to - 1 is out of bounds",
+ lambda: t.random_(from_, to_))
else:
t.random_(from_, to_)
range_ = to_ - from_
@@ -15147,10 +15137,8 @@
t = torch.empty(size, dtype=dtype, device=device)
if to_ > from_:
if not (min_val <= (to_ - 1) <= max_val):
- self.assertWarnsRegex(
- lambda: t.random_(to_),
- "to - 1 is out of bounds"
- )
+ self.assertWarnsRegex(UserWarning, "to - 1 is out of bounds",
+ lambda: t.random_(to_))
else:
t.random_(to_)
range_ = to_ - from_
@@ -16313,7 +16301,7 @@
def _small_3d_unique(dtype, device):
return (torch.randperm(_S * _S * _S,
- dtype=_convert_t(dtype, device), device=device) + 1).view(_S, _S, _S)
+ dtype=_convert_t(dtype, device), device=device) + 1).view(_S, _S, _S)
def _medium_1d(dtype, device):
return _make_tensor((_M,), dtype, device)
@@ -16874,14 +16862,14 @@
_TorchMathTestMeta('round'),
_TorchMathTestMeta('lgamma', reffn='gammaln', ref_backend='scipy'),
_TorchMathTestMeta('polygamma', args=[0], substr='_0', reffn='polygamma',
- refargs=lambda x: (0, x.numpy()), input_fn=_generate_gamma_input, inputargs=[False],
- ref_backend='scipy'),
+ refargs=lambda x: (0, x.numpy()), input_fn=_generate_gamma_input, inputargs=[False],
+ ref_backend='scipy'),
_TorchMathTestMeta('polygamma', args=[1], substr='_1', reffn='polygamma',
- refargs=lambda x: (1, x.numpy()), input_fn=_generate_gamma_input, inputargs=[False],
- ref_backend='scipy', rtol=0.0008, atol=1e-5),
+ refargs=lambda x: (1, x.numpy()), input_fn=_generate_gamma_input, inputargs=[False],
+ ref_backend='scipy', rtol=0.0008, atol=1e-5),
_TorchMathTestMeta('digamma',
- input_fn=_generate_gamma_input, inputargs=[True], ref_backend='scipy',
- replace_inf_with_nan=True)]
+ input_fn=_generate_gamma_input, inputargs=[True], ref_backend='scipy',
+ replace_inf_with_nan=True)]
def generate_torch_test_functions(cls, testmeta, inplace):
diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index 0c5bd26..3b30990 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -1007,27 +1007,6 @@
callable()
self.assertTrue(len(ws) == 0, msg)
- def assertWarns(self, callable, msg=''):
- r"""
- Test if :attr:`callable` raises a warning.
- """
- with self._reset_warning_registry(), warnings.catch_warnings(record=True) as ws:
- warnings.simplefilter("always") # allow any warning to be raised
- callable()
- self.assertTrue(len(ws) > 0, msg)
-
- def assertWarnsRegex(self, callable, regex, msg=''):
- r"""
- Test if :attr:`callable` raises any warning with message that contains
- the regex pattern :attr:`regex`.
- """
- with self._reset_warning_registry(), warnings.catch_warnings(record=True) as ws:
- warnings.simplefilter("always") # allow any warning to be raised
- callable()
- self.assertTrue(len(ws) > 0, msg)
- found = any(re.search(regex, str(w.message)) is not None for w in ws)
- self.assertTrue(found, msg)
-
@contextmanager
def maybeWarnsRegex(self, category, regex=''):
"""Context manager for code that *may* warn, e.g. ``TORCH_WARN_ONCE``.