#!/usr/bin/env python3 | |
from __future__ import absolute_import, division, print_function, unicode_literals | |
import torch | |
# rpc_fork tests use double as the default dtype | |
torch.set_default_dtype(torch.double) | |
from rpc_test import RpcTest | |
from common_distributed import MultiProcessTestCase | |
from common_utils import run_tests | |
class RpcTestWithFork(MultiProcessTestCase, RpcTest):
    """Run the ``RpcTest`` cases on top of ``MultiProcessTestCase`` using fork.

    The class adds no tests of its own; it only combines the two bases and
    overrides ``setUp`` so that processes are forked for every test.
    """

    def setUp(self):
        """Run the inherited set-up, then fork the test processes."""
        super(RpcTestWithFork, self).setUp()
        # Fork only after the inherited fixtures have been initialised.
        self._fork_processes()
if __name__ == '__main__':
    # Executed as a script: hand control to the shared test runner.
    run_tests()