blob: e57b4c3803b57f14f81f1c111f04ce2dd9344e38 [file] [log] [blame]
#!/usr/bin/env python3
from __future__ import absolute_import, division, print_function, unicode_literals
import torch
# dist_autograd_fork tests use double as the default dtype
torch.set_default_dtype(torch.double)
from dist_autograd_test import DistAutogradTest
from common_distributed import MultiProcessTestCase
from common_utils import run_tests
class DistAutogradTestWithFork(MultiProcessTestCase, DistAutogradTest):
def setUp(self):
super(DistAutogradTestWithFork, self).setUp()
self._fork_processes()
if __name__ == '__main__':
run_tests()