Use unittest assertWarns instead (#36411)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/36411

This PR removes the PyTorch-specific assertWarns/assertWarnsRegex helpers defined in
common_utils.py and switches the tests to the unittest built-ins; it also reformats some tests.
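
For reference, a minimal sketch of the two unittest call patterns the updated tests rely on.
The removed helpers took the callable first (assertWarnsRegex(callable, regex, msg='')); the
unittest versions take the warning category first and support either a context-manager or a
callable form. The WarnExample class below is illustrative only, not part of this change:

    import unittest
    import warnings

    class WarnExample(unittest.TestCase):
        def test_context_manager_form(self):
            # Context-manager form: the block must emit a UserWarning whose
            # message matches the regex; msg= customizes the failure message.
            with self.assertWarnsRegex(UserWarning, r"deprecated",
                                       msg="should warn about deprecation"):
                warnings.warn("this API is deprecated", UserWarning)

        def test_callable_form(self):
            # Callable form: warning category and regex come first,
            # then the callable (here a lambda) and its arguments.
            self.assertWarnsRegex(
                UserWarning, r"deprecated",
                lambda: warnings.warn("this API is deprecated", UserWarning))

    if __name__ == "__main__":
        unittest.main()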

Test Plan: Imported from OSS

Differential Revision: D20998159

Pulled By: wanchaol

fbshipit-source-id: 1280ecff2dd293b95a639d13cc7417fc819c2201
diff --git a/test/distributed/test_data_parallel.py b/test/distributed/test_data_parallel.py
index 5d7f2c5..cdd3fcf 100644
--- a/test/distributed/test_data_parallel.py
+++ b/test/distributed/test_data_parallel.py
@@ -644,9 +644,10 @@
                 self._testcase = testcase
 
             def forward(self, x):
-                self._testcase.assertWarnsRegex(
-                    lambda: self.zero_grad(),
-                    r"Calling \.zero_grad\(\) from a module that was passed to a nn\.DataParallel\(\) has no effect.")
+                with self._testcase.assertWarnsRegex(
+                        UserWarning,
+                        r"Calling \.zero_grad\(\) from a module that was passed to a nn\.DataParallel\(\) has no effect."):
+                    self.zero_grad()
                 return x
 
         module = Net(self).cuda()
diff --git a/test/test_dataloader.py b/test/test_dataloader.py
index 2918411..c18f12d 100644
--- a/test/test_dataloader.py
+++ b/test/test_dataloader.py
@@ -320,6 +320,8 @@
 set_faulthander_if_available()
 
 # Process `pid` must have called `set_faulthander_if_available`
+
+
 def print_traces_of_all_threads(pid):
     if HAS_FAULTHANDLER:
         if not IS_WINDOWS:
@@ -902,7 +904,6 @@
         with self.assertRaisesRegex(ValueError, "timeout option should be non-negative"):
             DataLoader(self.dataset, timeout=-1)
 
-
         # disable auto-batching
         with self.assertRaisesRegex(ValueError,
                                     "batch_size=None option disables auto-batching and is mutually exclusive"):
@@ -1024,10 +1025,11 @@
         for _ in range(20):
             self.assertNotWarn(lambda: next(it), "Should not warn before exceeding length")
         for _ in range(3):
-            self.assertWarnsRegex(
-                lambda: next(it),
+            with self.assertWarnsRegex(
+                UserWarning,
                 r"but [0-9]+ samples have been fetched\. For multiprocessing data-loading, this",
-                "Should always warn after exceeding length")
+                    msg="Should always warn after exceeding length"):
+                next(it)
 
         # [no auto-batching] test that workers exit gracefully
         workers = dataloader_iter._workers
diff --git a/test/test_jit.py b/test/test_jit.py
index 0aeebc1..1486049 100644
--- a/test/test_jit.py
+++ b/test/test_jit.py
@@ -406,7 +406,7 @@
             def __init__(self, cpu_device_str):
                 super(M, self).__init__()
                 self.p0 = nn.Parameter(torch.tensor([0.3], dtype=torch.float,
-                                       device=cpu_device_str))
+                                                    device=cpu_device_str))
                 self.b0 = torch.tensor([0.9], dtype=torch.float,
                                        device=cpu_device_str)
 
@@ -14263,12 +14263,6 @@
             torch.testing.assert_allclose(y_int8, y_ref, rtol=0.0001, atol=1e-3)
             torch.testing.assert_allclose(y_fp16, y_ref, rtol=0.0001, atol=1e-3)
 
-    def checkTracerWarning(self, *args, **kwargs):
-        with warnings.catch_warnings(record=True) as warns:
-            torch.jit.trace(*args, **kwargs)
-        self.assertGreater(len(warns), 0)
-        self.assertTrue(any(["cause the trace to be incorrect" in str(warn.message) for warn in warns]))
-
     def test_trace_checker_slice_lhs(self):
         def foo(x):
             for i in range(3):
@@ -14282,18 +14276,21 @@
             x.view(-1).add_(-x.view(-1))
             return x
 
-        self.assertWarnsRegex(lambda: torch.jit.trace(foo,
-                                                      torch.rand(3, 4),
-                                                      check_inputs=[torch.rand(5, 6)],
-                                                      _force_outplace=True),
-                              'Output nr 1. of the traced function does not match the '
-                              'corresponding output of the Python function')
+        with self.assertWarnsRegex(torch.jit.TracerWarning,
+                                   'Output nr 1. of the traced function does not match the '
+                                   'corresponding output of the Python function'):
+            torch.jit.trace(foo,
+                            torch.rand(3, 4),
+                            check_inputs=[torch.rand(5, 6)],
+                            _force_outplace=True)
 
     def test_lhs_index_fails(self):
         def foo(x):
             x[0, 1] = 4
             return x
-        self.checkTracerWarning(foo, torch.rand(3, 4), _force_outplace=True)
+
+        with self.assertWarnsRegex(torch.jit.TracerWarning, "cause the trace to be incorrect"):
+            torch.jit.trace(foo, torch.rand(3, 4), _force_outplace=True)
 
     def test_lhs_index_trivial(self):
         def foo(y, x):
@@ -14305,18 +14302,23 @@
         def foo(x):
             x.view(-1).add_(-x.view(-1))
             return x
-        self.checkTracerWarning(foo, torch.rand(3, 4), _force_outplace=True)
+
+        with self.assertWarnsRegex(torch.jit.TracerWarning, "cause the trace to be incorrect"):
+            torch.jit.trace(foo, torch.rand(3, 4), _force_outplace=True)
 
     @suppress_warnings
     def test_trace_checker_dropout_train(self):
         def foo(x):
             return torch.dropout(x, p=0.5, train=True)
 
-        self.assertWarnsRegex(lambda: torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[torch.rand(5, 6)]),
-                              'Output nr 1. of the traced function does not match the '
-                              'corresponding output of the Python function')
-        self.assertWarnsRegex(lambda: torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[torch.rand(5, 6)]),
-                              'Trace had nondeterministic nodes')
+        with self.assertWarnsRegex(torch.jit.TracerWarning,
+                                   'Output nr 1. of the traced function does not match the '
+                                   'corresponding output of the Python function'):
+            torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[torch.rand(5, 6)])
+
+        with self.assertWarnsRegex(torch.jit.TracerWarning,
+                                   'Trace had nondeterministic nodes'):
+            torch.jit.trace(foo, torch.rand(3, 4), check_inputs=[torch.rand(5, 6)])
 
     def test_trace_checker_dropout_notrain(self):
         input = torch.rand(3, 4)
diff --git a/test/test_mkldnn.py b/test/test_mkldnn.py
index 2d21d1c..93ddd1f 100644
--- a/test/test_mkldnn.py
+++ b/test/test_mkldnn.py
@@ -67,10 +67,12 @@
 
         # because MKLDNN only supports float32, we need to lessen the precision.
         # these numbers are just empirical results that seem to work.
-        self.assertWarnsRegex(lambda: gradcheck(func, [root], atol=4e-2, rtol=1e-2),
-                              'double precision floating point')
-        self.assertWarnsRegex(lambda: gradgradcheck(func, [root], atol=4e-2, rtol=1e-2),
-                              'double precision floating point')
+        self.assertWarnsRegex(UserWarning,
+                              'double precision floating point',
+                              lambda: gradcheck(func, [root], atol=4e-2, rtol=1e-2))
+        self.assertWarnsRegex(UserWarning,
+                              'double precision floating point',
+                              lambda: gradgradcheck(func, [root], atol=4e-2, rtol=1e-2))
 
     def test_autograd_from_mkldnn(self):
         # MKLDNN only supports float32
@@ -81,8 +83,9 @@
 
         # because MKLDNN only supports float32, we need to lessen the precision.
         # these numbers are just empirical results that seem to work.
-        self.assertWarnsRegex(lambda: gradcheck(func, [root], atol=4e-2, rtol=1e-2),
-                              'double precision floating point')
+        self.assertWarnsRegex(UserWarning,
+                              'double precision floating point',
+                              lambda: gradcheck(func, [root], atol=4e-2, rtol=1e-2))
 
     def test_detach(self):
         root = torch.randn(4, 5, dtype=torch.float32).to_mkldnn().requires_grad_()
diff --git a/test/test_nn.py b/test/test_nn.py
index 6aaa480..1bf4154 100644
--- a/test/test_nn.py
+++ b/test/test_nn.py
@@ -8368,7 +8368,9 @@
 
         def fn():
             init.normal(x)
-        self.assertWarnsRegex(fn, 'deprecated', 'methods not suffixed with underscore should be deprecated')
+
+        with self.assertWarnsRegex(UserWarning, 'deprecated', msg='methods not suffixed with underscore should be deprecated'):
+            fn()
 
 class TestFusionEval(TestCase):
     @given(X=hu.tensor(shapes=((5, 3, 5, 5),)),
diff --git a/test/test_optim.py b/test/test_optim.py
index 22eeaf7..b0d502b 100644
--- a/test/test_optim.py
+++ b/test/test_optim.py
@@ -575,7 +575,7 @@
                 scheduler.step()
                 self.opt.step()
 
-        self.assertWarnsRegex(old_pattern, r'how-to-adjust-learning-rate')
+        self.assertWarnsRegex(UserWarning, r'how-to-adjust-learning-rate', old_pattern)
 
     def test_old_pattern_warning_with_arg(self):
         epochs = 35
@@ -589,7 +589,7 @@
                 scheduler.step()
                 self.opt.step()
 
-        self.assertWarnsRegex(old_pattern2, r'how-to-adjust-learning-rate')
+        self.assertWarnsRegex(UserWarning, r'how-to-adjust-learning-rate', old_pattern2)
 
     def test_old_pattern_warning_resuming(self):
         epochs = 35
@@ -606,7 +606,7 @@
                 scheduler.step()
                 self.opt.step()
 
-        self.assertWarnsRegex(old_pattern, r'how-to-adjust-learning-rate')
+        self.assertWarnsRegex(UserWarning, r'how-to-adjust-learning-rate', old_pattern)
 
     def test_old_pattern_warning_resuming_with_arg(self):
         epochs = 35
@@ -623,7 +623,7 @@
                 scheduler.step()
                 self.opt.step()
 
-        self.assertWarnsRegex(old_pattern2, r'how-to-adjust-learning-rate')
+        self.assertWarnsRegex(UserWarning, r'how-to-adjust-learning-rate', old_pattern2)
 
     def test_old_pattern_warning_with_overridden_optim_step(self):
         epochs = 35
@@ -651,7 +651,7 @@
                 scheduler.step()
                 self.opt.step()
 
-        self.assertWarnsRegex(old_pattern2, r'how-to-adjust-learning-rate')
+        self.assertWarnsRegex(UserWarning, r'how-to-adjust-learning-rate', old_pattern2)
 
     def test_new_pattern_no_warning(self):
         epochs = 35
@@ -704,7 +704,7 @@
                 self.opt.step()
                 scheduler.step()
 
-        self.assertWarnsRegex(new_pattern, r'`optimizer.step\(\)` has been overridden')
+        self.assertWarnsRegex(UserWarning, r'`optimizer.step\(\)` has been overridden', new_pattern)
 
     def _test_lr_is_constant_for_constant_epoch(self, scheduler):
         l = []
diff --git a/test/test_torch.py b/test/test_torch.py
index cb35f75..3a1b4e7 100644
--- a/test/test_torch.py
+++ b/test/test_torch.py
@@ -527,7 +527,7 @@
         self.assertEqual(n.shape, t.shape)
         if t.dtype == torch.float:
             self.assertTrue(np.allclose(n, t.numpy(), rtol=1e-03, atol=1e-05,
-                            equal_nan=True))
+                                        equal_nan=True))
         else:
             self.assertTrue(np.allclose(n, t.numpy(), equal_nan=True))
 
@@ -549,29 +549,29 @@
                                                               dim).cpu(),
                                                    expected)
         do_one(self._make_tensors((5, 400000), use_floating=use_floating,
-               use_integral=use_integral), 1)
+                                  use_integral=use_integral), 1)
         do_one(self._make_tensors((3, 5, 7), use_floating=use_floating,
-               use_integral=use_integral), 0)
+                                  use_integral=use_integral), 0)
         do_one(self._make_tensors((3, 5, 7), use_floating=use_floating,
-               use_integral=use_integral), 1)
+                                  use_integral=use_integral), 1)
         do_one(self._make_tensors((3, 5, 7), use_floating=use_floating,
-               use_integral=use_integral), 2)
+                                  use_integral=use_integral), 2)
         do_one(self._make_tensors((100000, ), use_floating=use_floating,
-               use_integral=use_integral), -1)
+                                  use_integral=use_integral), -1)
         do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
-               use_integral=use_integral), 0)
+                                  use_integral=use_integral), 0)
         do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
-               use_integral=use_integral), 1)
+                                  use_integral=use_integral), 1)
         do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
-               use_integral=use_integral), 2)
+                                  use_integral=use_integral), 2)
         do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
-               use_integral=use_integral), (1, 2))
+                                  use_integral=use_integral), (1, 2))
         do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
-               use_integral=use_integral), (1, -1))
+                                  use_integral=use_integral), (1, -1))
         do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
-               use_integral=use_integral), (0, 2))
+                                  use_integral=use_integral), (0, 2))
         do_one(self._make_tensors((50, 50, 50), use_floating=use_floating,
-               use_integral=use_integral), (0, 2, 1))
+                                  use_integral=use_integral), (0, 2, 1))
 
     @slowTest
     @unittest.skipIf(not TEST_NUMPY, 'Numpy not found')
@@ -1454,9 +1454,8 @@
 
         # Test Rounding Errors
         line = torch.zeros(size=(1, 49))
-        self.assertWarnsRegex(lambda: torch.arange(-1, 1, 2. / 49, dtype=torch.float32, out=line),
-                              'resized',
-                              'The out tensor will be resized')
+        self.assertWarnsRegex(UserWarning, 'The out tensor will be resized',
+                              lambda: torch.arange(-1, 1, 2. / 49, dtype=torch.float32, out=line))
         self.assertEqual(line.shape, [50])
 
         x = torch.empty(1).expand(10)
@@ -2274,7 +2273,7 @@
     def test_numpy_non_writeable(self):
         arr = np.zeros(5)
         arr.flags['WRITEABLE'] = False
-        self.assertWarns(lambda: torch.from_numpy(arr))
+        self.assertWarns(UserWarning, lambda: torch.from_numpy(arr))
 
     @unittest.skipIf(not TEST_NUMPY, "Numpy not found")
     def test_empty_storage_view(self):
@@ -5165,9 +5164,8 @@
             # inputs are non-expandable tensors, but they have same number of elements
             # TORCH_WARN_ONCE is used in torch.normal, only 1st assertEqual will show warn msg
             if not warned:
-                self.assertWarnsRegex(
-                    lambda: self.assertEqual(torch.normal(tensor120, tensor2345).size(), (120,)),
-                    "deprecated and the support will be removed")
+                self.assertWarnsRegex(UserWarning, "deprecated and the support will be removed",
+                                      lambda: self.assertEqual(torch.normal(tensor120, tensor2345).size(), (120,)))
                 warned = True
             else:
                 self.assertEqual(torch.normal(tensor120, tensor2345).size(), (120,))
@@ -7836,19 +7834,19 @@
                    [1, 2]])
         columns = [0],
         self.assertEqual(reference[rows, columns], torch.tensor([[1, 1],
-                                                                [3, 5]], dtype=dtype, device=device))
+                                                                 [3, 5]], dtype=dtype, device=device))
 
         rows = ri([[0, 0],
                    [1, 2]])
         columns = ri([1, 0])
         self.assertEqual(reference[rows, columns], torch.tensor([[2, 1],
-                                                                [4, 5]], dtype=dtype, device=device))
+                                                                 [4, 5]], dtype=dtype, device=device))
         rows = ri([[0, 0],
                    [1, 2]])
         columns = ri([[0, 1],
                       [1, 0]])
         self.assertEqual(reference[rows, columns], torch.tensor([[1, 2],
-                                                                [4, 5]], dtype=dtype, device=device))
+                                                                 [4, 5]], dtype=dtype, device=device))
 
         # setting values
         reference[ri([0]), ri([1])] = -1
@@ -10393,8 +10391,8 @@
         self.assertEqual(res1, res2)
 
         a = torch.tensor([[True, False, True],
-                         [False, False, False],
-                         [True, True, True]], dtype=torch.bool, device=device)
+                          [False, False, False],
+                          [True, True, True]], dtype=torch.bool, device=device)
         b = a.byte()
         aRes = torch.cumprod(a, 0)
         bRes = torch.cumprod(b, 0)
@@ -10444,8 +10442,8 @@
             self.assertEqual(out1[1], indices2)
 
             a = torch.tensor([[True, False, True],
-                             [False, False, False],
-                             [True, True, True]], dtype=torch.bool, device=device)
+                              [False, False, False],
+                              [True, True, True]], dtype=torch.bool, device=device)
             b = a.byte()
             aRes = op(a, 0)
             bRes = op(b, 0)
@@ -12674,7 +12672,7 @@
         self.assertEqual(sort_topk, a[topk[1]])   # check indices
 
     @dtypesIfCUDA(*([torch.half, torch.float, torch.double]
-                  + ([torch.bfloat16] if TEST_WITH_ROCM else [])))
+                    + ([torch.bfloat16] if TEST_WITH_ROCM else [])))
     @dtypes(torch.float, torch.double)
     def test_topk_nonfinite(self, device, dtype):
         x = torch.tensor([float('nan'), float('inf'), 1e4, 0, -1e4, -float('inf')], device=device, dtype=dtype)
@@ -13073,7 +13071,7 @@
         self.assertEqual(input_values.erf().erfinv(), input_values)
         # test inf
         self.assertTrue(torch.equal(torch.tensor([-1, 1], dtype=dtype, device=device).erfinv(),
-                        torch.tensor([-inf, inf], dtype=dtype, device=device)))
+                                    torch.tensor([-inf, inf], dtype=dtype, device=device)))
         # test nan
         self.assertEqual(torch.tensor([-2, 2], dtype=dtype, device=device).erfinv(),
                          torch.tensor([nan, nan], dtype=dtype, device=device))
@@ -14976,15 +14974,11 @@
                 if to_ > from_:
                     if not (min_val <= from_ <= max_val) or not (min_val <= (to_ - 1) <= max_val):
                         if not (min_val <= from_ <= max_val):
-                            self.assertWarnsRegex(
-                                lambda: t.random_(from_, to_),
-                                "from is out of bounds"
-                            )
+                            self.assertWarnsRegex(UserWarning, "from is out of bounds",
+                                                  lambda: t.random_(from_, to_))
                         if not (min_val <= (to_ - 1) <= max_val):
-                            self.assertWarnsRegex(
-                                lambda: t.random_(from_, to_),
-                                "to - 1 is out of bounds"
-                            )
+                            self.assertWarnsRegex(UserWarning, "to - 1 is out of bounds",
+                                                  lambda: t.random_(from_, to_))
                     else:
                         t.random_(from_, to_)
                         range_ = to_ - from_
@@ -15078,15 +15072,11 @@
                 if to_ > from_:
                     if not (min_val <= from_ <= max_val) or not (min_val <= (to_ - 1) <= max_val):
                         if not (min_val <= from_ <= max_val):
-                            self.assertWarnsRegex(
-                                lambda: t.random_(from_, to_),
-                                "from is out of bounds"
-                            )
+                            self.assertWarnsRegex(UserWarning, "from is out of bounds",
+                                                  lambda: t.random_(from_, to_))
                         if not (min_val <= (to_ - 1) <= max_val):
-                            self.assertWarnsRegex(
-                                lambda: t.random_(from_, to_),
-                                "to - 1 is out of bounds"
-                            )
+                            self.assertWarnsRegex(UserWarning, "to - 1 is out of bounds",
+                                                  lambda: t.random_(from_, to_))
                     else:
                         t.random_(from_, to_)
                         range_ = to_ - from_
@@ -15147,10 +15137,8 @@
             t = torch.empty(size, dtype=dtype, device=device)
             if to_ > from_:
                 if not (min_val <= (to_ - 1) <= max_val):
-                    self.assertWarnsRegex(
-                        lambda: t.random_(to_),
-                        "to - 1 is out of bounds"
-                    )
+                    self.assertWarnsRegex(UserWarning, "to - 1 is out of bounds",
+                                          lambda: t.random_(to_))
                 else:
                     t.random_(to_)
                     range_ = to_ - from_
@@ -16313,7 +16301,7 @@
 
 def _small_3d_unique(dtype, device):
     return (torch.randperm(_S * _S * _S,
-            dtype=_convert_t(dtype, device), device=device) + 1).view(_S, _S, _S)
+                           dtype=_convert_t(dtype, device), device=device) + 1).view(_S, _S, _S)
 
 def _medium_1d(dtype, device):
     return _make_tensor((_M,), dtype, device)
@@ -16874,14 +16862,14 @@
                   _TorchMathTestMeta('round'),
                   _TorchMathTestMeta('lgamma', reffn='gammaln', ref_backend='scipy'),
                   _TorchMathTestMeta('polygamma', args=[0], substr='_0', reffn='polygamma',
-                  refargs=lambda x: (0, x.numpy()), input_fn=_generate_gamma_input, inputargs=[False],
-                  ref_backend='scipy'),
+                                     refargs=lambda x: (0, x.numpy()), input_fn=_generate_gamma_input, inputargs=[False],
+                                     ref_backend='scipy'),
                   _TorchMathTestMeta('polygamma', args=[1], substr='_1', reffn='polygamma',
-                  refargs=lambda x: (1, x.numpy()), input_fn=_generate_gamma_input, inputargs=[False],
-                  ref_backend='scipy', rtol=0.0008, atol=1e-5),
+                                     refargs=lambda x: (1, x.numpy()), input_fn=_generate_gamma_input, inputargs=[False],
+                                     ref_backend='scipy', rtol=0.0008, atol=1e-5),
                   _TorchMathTestMeta('digamma',
-                  input_fn=_generate_gamma_input, inputargs=[True], ref_backend='scipy',
-                  replace_inf_with_nan=True)]
+                                     input_fn=_generate_gamma_input, inputargs=[True], ref_backend='scipy',
+                                     replace_inf_with_nan=True)]
 
 
 def generate_torch_test_functions(cls, testmeta, inplace):
diff --git a/torch/testing/_internal/common_utils.py b/torch/testing/_internal/common_utils.py
index 0c5bd26..3b30990 100644
--- a/torch/testing/_internal/common_utils.py
+++ b/torch/testing/_internal/common_utils.py
@@ -1007,27 +1007,6 @@
             callable()
             self.assertTrue(len(ws) == 0, msg)
 
-    def assertWarns(self, callable, msg=''):
-        r"""
-        Test if :attr:`callable` raises a warning.
-        """
-        with self._reset_warning_registry(), warnings.catch_warnings(record=True) as ws:
-            warnings.simplefilter("always")  # allow any warning to be raised
-            callable()
-            self.assertTrue(len(ws) > 0, msg)
-
-    def assertWarnsRegex(self, callable, regex, msg=''):
-        r"""
-        Test if :attr:`callable` raises any warning with message that contains
-        the regex pattern :attr:`regex`.
-        """
-        with self._reset_warning_registry(), warnings.catch_warnings(record=True) as ws:
-            warnings.simplefilter("always")  # allow any warning to be raised
-            callable()
-            self.assertTrue(len(ws) > 0, msg)
-            found = any(re.search(regex, str(w.message)) is not None for w in ws)
-            self.assertTrue(found, msg)
-
     @contextmanager
     def maybeWarnsRegex(self, category, regex=''):
         """Context manager for code that *may* warn, e.g. ``TORCH_WARN_ONCE``.