| import collections | 
 | import unittest | 
 |  | 
 | import torch | 
 | from torch.testing._internal.common_utils import ( | 
 |     TestCase, run_tests, TEST_WITH_ASAN) | 
 |  | 
 | try: | 
 |     import psutil | 
 |     HAS_PSUTIL = True | 
 | except ImportError: | 
 |     HAS_PSUTIL = False | 
 |  | 
 | device = torch.device('cpu') | 
 |  | 
 |  | 
 | class Network(torch.nn.Module): | 
 |     maxp1 = torch.nn.MaxPool2d(1, 1) | 
 |  | 
 |     def forward(self, x): | 
 |         return self.maxp1(x) | 
 |  | 
 |  | 
 | @unittest.skipIf(not HAS_PSUTIL, "Requires psutil to run") | 
 | @unittest.skipIf(TEST_WITH_ASAN, "Cannot test with ASAN") | 
 | class TestOpenMP_ParallelFor(TestCase): | 
 |     batch = 20 | 
 |     channels = 1 | 
 |     side_dim = 80 | 
 |     x = torch.randn([batch, channels, side_dim, side_dim], device=device) | 
 |     model = Network() | 
 |     ncores = min(5, psutil.cpu_count(logical=False)) | 
 |  | 
 |     def func(self, runs): | 
 |         p = psutil.Process() | 
 |         # warm up for 5 runs, then things should be stable for the last 5 | 
 |         last_rss = collections.deque(maxlen=5) | 
 |         for n in range(10): | 
 |             for i in range(runs): | 
 |                 self.model(self.x) | 
 |             last_rss.append(p.memory_info().rss) | 
 |         return last_rss | 
 |  | 
 |     def func_rss(self, runs): | 
 |         last_rss = list(self.func(runs)) | 
 |         # Check that the sequence is not strictly increasing | 
 |         is_increasing = True | 
 |         for idx in range(len(last_rss)): | 
 |             if idx == 0: | 
 |                 continue | 
 |             is_increasing = is_increasing and (last_rss[idx] > last_rss[idx - 1]) | 
 |         self.assertTrue(not is_increasing, | 
 |                         msg='memory usage is increasing, {}'.format(str(last_rss))) | 
 |  | 
 |     def test_one_thread(self): | 
 |         """Make sure there is no memory leak with one thread: issue gh-32284 | 
 |         """ | 
 |         torch.set_num_threads(1) | 
 |         self.func_rss(300) | 
 |  | 
 |     def test_n_threads(self): | 
 |         """Make sure there is no memory leak with many threads | 
 |         """ | 
 |         torch.set_num_threads(self.ncores) | 
 |         self.func_rss(300) | 
 |  | 
 | if __name__ == '__main__': | 
 |     run_tests() |