| import threading |
| import time |
| import torch |
| from torch.futures import Future |
| from torch.testing._internal.common_utils import TestCase, TemporaryFileName |
| |
| |
| def add_one(fut): |
| return fut.wait() + 1 |
| |
| |
| class TestFuture(TestCase): |
| def test_wait(self): |
| f = Future() |
| f.set_result(torch.ones(2, 2)) |
| |
| self.assertEqual(f.wait(), torch.ones(2, 2)) |
| |
| def test_wait_multi_thread(self): |
| |
| def slow_set_future(fut, value): |
| time.sleep(0.5) |
| fut.set_result(value) |
| |
| f = Future() |
| |
| t = threading.Thread(target=slow_set_future, args=(f, torch.ones(2, 2))) |
| t.start() |
| |
| self.assertEqual(f.wait(), torch.ones(2, 2)) |
| t.join() |
| |
| def test_mark_future_twice(self): |
| fut = Future() |
| fut.set_result(1) |
| with self.assertRaisesRegex( |
| RuntimeError, |
| "Future can only be marked completed once" |
| ): |
| fut.set_result(1) |
| |
| def test_pickle_future(self): |
| fut = Future() |
| errMsg = "Can not pickle torch.futures.Future" |
| with TemporaryFileName() as fname: |
| with self.assertRaisesRegex(RuntimeError, errMsg): |
| torch.save(fut, fname) |
| |
| def test_then(self): |
| fut = Future() |
| then_fut = fut.then(lambda x: x.wait() + 1) |
| |
| fut.set_result(torch.ones(2, 2)) |
| self.assertEqual(fut.wait(), torch.ones(2, 2)) |
| self.assertEqual(then_fut.wait(), torch.ones(2, 2) + 1) |
| |
| def test_chained_then(self): |
| fut = Future() |
| futs = [] |
| last_fut = fut |
| for _ in range(20): |
| last_fut = last_fut.then(add_one) |
| futs.append(last_fut) |
| |
| fut.set_result(torch.ones(2, 2)) |
| |
| for i in range(len(futs)): |
| self.assertEqual(futs[i].wait(), torch.ones(2, 2) + i + 1) |
| |
| def _test_error(self, cb, errMsg): |
| fut = Future() |
| then_fut = fut.then(cb) |
| |
| fut.set_result(5) |
| self.assertEqual(5, fut.wait()) |
| with self.assertRaisesRegex(RuntimeError, errMsg): |
| then_fut.wait() |
| |
| def test_then_wrong_arg(self): |
| |
| def wrong_arg(tensor): |
| return tensor + 1 |
| |
| self._test_error(wrong_arg, "unsupported operand type.*Future.*int") |
| |
| def test_then_no_arg(self): |
| |
| def no_arg(): |
| return True |
| |
| self._test_error(no_arg, "takes 0 positional arguments but 1 was given") |
| |
| def test_then_raise(self): |
| |
| def raise_value_error(fut): |
| raise ValueError("Expected error") |
| |
| self._test_error(raise_value_error, "Expected error") |
| |
| def test_collect_all(self): |
| fut1 = Future() |
| fut2 = Future() |
| fut_all = torch.futures.collect_all([fut1, fut2]) |
| |
| def slow_in_thread(fut, value): |
| time.sleep(0.1) |
| fut.set_result(value) |
| |
| t = threading.Thread(target=slow_in_thread, args=(fut1, 1)) |
| fut2.set_result(2) |
| t.start() |
| |
| res = fut_all.wait() |
| self.assertEqual(res[0].wait(), 1) |
| self.assertEqual(res[1].wait(), 2) |
| t.join() |
| |
| def test_wait_all(self): |
| fut1 = Future() |
| fut2 = Future() |
| |
| # No error version |
| fut1.set_result(1) |
| fut2.set_result(2) |
| res = torch.futures.wait_all([fut1, fut2]) |
| print(res) |
| self.assertEqual(res, [1, 2]) |
| |
| # Version with an exception |
| def raise_in_fut(fut): |
| raise ValueError("Expected error") |
| fut3 = fut1.then(raise_in_fut) |
| with self.assertRaisesRegex(RuntimeError, "Expected error"): |
| torch.futures.wait_all([fut3, fut2]) |