blob: c496ea1cb1076997b9f3422abfba8d69e9d01523 [file] [log] [blame]
import threading
import time
import torch
from torch.futures import Future
from torch.testing._internal.common_utils import TestCase, TemporaryFileName
def add_one(fut):
return fut.wait() + 1
class TestFuture(TestCase):
def test_wait(self):
f = Future()
f.set_result(torch.ones(2, 2))
self.assertEqual(f.wait(), torch.ones(2, 2))
def test_wait_multi_thread(self):
def slow_set_future(fut, value):
time.sleep(0.5)
fut.set_result(value)
f = Future()
t = threading.Thread(target=slow_set_future, args=(f, torch.ones(2, 2)))
t.start()
self.assertEqual(f.wait(), torch.ones(2, 2))
t.join()
def test_mark_future_twice(self):
fut = Future()
fut.set_result(1)
with self.assertRaisesRegex(
RuntimeError,
"Future can only be marked completed once"
):
fut.set_result(1)
def test_pickle_future(self):
fut = Future()
errMsg = "Can not pickle torch.futures.Future"
with TemporaryFileName() as fname:
with self.assertRaisesRegex(RuntimeError, errMsg):
torch.save(fut, fname)
def test_then(self):
fut = Future()
then_fut = fut.then(lambda x: x.wait() + 1)
fut.set_result(torch.ones(2, 2))
self.assertEqual(fut.wait(), torch.ones(2, 2))
self.assertEqual(then_fut.wait(), torch.ones(2, 2) + 1)
def test_chained_then(self):
fut = Future()
futs = []
last_fut = fut
for _ in range(20):
last_fut = last_fut.then(add_one)
futs.append(last_fut)
fut.set_result(torch.ones(2, 2))
for i in range(len(futs)):
self.assertEqual(futs[i].wait(), torch.ones(2, 2) + i + 1)
def _test_error(self, cb, errMsg):
fut = Future()
then_fut = fut.then(cb)
fut.set_result(5)
self.assertEqual(5, fut.wait())
with self.assertRaisesRegex(RuntimeError, errMsg):
then_fut.wait()
def test_then_wrong_arg(self):
def wrong_arg(tensor):
return tensor + 1
self._test_error(wrong_arg, "unsupported operand type.*Future.*int")
def test_then_no_arg(self):
def no_arg():
return True
self._test_error(no_arg, "takes 0 positional arguments but 1 was given")
def test_then_raise(self):
def raise_value_error(fut):
raise ValueError("Expected error")
self._test_error(raise_value_error, "Expected error")
def test_collect_all(self):
fut1 = Future()
fut2 = Future()
fut_all = torch.futures.collect_all([fut1, fut2])
def slow_in_thread(fut, value):
time.sleep(0.1)
fut.set_result(value)
t = threading.Thread(target=slow_in_thread, args=(fut1, 1))
fut2.set_result(2)
t.start()
res = fut_all.wait()
self.assertEqual(res[0].wait(), 1)
self.assertEqual(res[1].wait(), 2)
t.join()
def test_wait_all(self):
fut1 = Future()
fut2 = Future()
# No error version
fut1.set_result(1)
fut2.set_result(2)
res = torch.futures.wait_all([fut1, fut2])
print(res)
self.assertEqual(res, [1, 2])
# Version with an exception
def raise_in_fut(fut):
raise ValueError("Expected error")
fut3 = fut1.then(raise_in_fut)
with self.assertRaisesRegex(RuntimeError, "Expected error"):
torch.futures.wait_all([fut3, fut2])