blob: 89718981b4d60e7a38d263a6b99a7538461b86c8 [file] [log] [blame]
import torch
import warnings
# Emit a warning as soon as the module is imported: the package is experimental
# and its API may change without notice.
# NOTE(review): the message contains a typo ("we're can't"); it is runtime text,
# so it is intentionally left untouched here.
warnings.warn("""
================================================================================
WARNING
================================================================================
torch.distributed is a highly experimental package. The API will change without
notice and we're can't guarantee full correctness and expected performance yet.
We'll announce it once it's ready.
""")
# Flipped to True by init_process_group() / init_master_worker(); both check it
# first and refuse to initialise the package a second time.
_initialized = False
# Result of locals() at module scope, i.e. this module's namespace dictionary.
# extend_scope() writes the public names of helper modules into it.
_scope = locals()
def extend_scope(module):
    """Copy every public attribute of ``module`` into ``_scope``.

    An attribute counts as public when its name, as listed by ``dir(module)``,
    does not start with an underscore. Existing entries in ``_scope`` with the
    same name are overwritten.
    """
    public_attrs = {}
    for attr_name in dir(module):
        if attr_name.startswith('_'):
            continue
        public_attrs[attr_name] = getattr(module, attr_name)
    # Apply everything in one step, after all attributes have been read.
    _scope.update(public_attrs)
def init_process_group(backend):
    """Initialise torch.distributed through the process-group entry point.

    Calls ``torch._C._dist_init_process_group(backend)``, marks the package as
    initialised, copies the public names of ``torch.distributed.collectives``
    into this module's scope, and finally initialises the C extension.

    Args:
        backend: Passed unchanged to ``torch._C._dist_init_process_group``.

    Raises:
        RuntimeError: If torch.distributed has already been initialised, or if
            ``torch._C._dist_init_extension`` returns a false value.
    """
    global _initialized
    if _initialized:
        raise RuntimeError("trying to initialize torch.distributed twice!")
    torch._C._dist_init_process_group(backend)
    _initialized = True
    # Imported only after the backend initialisation call above has returned.
    import torch.distributed.collectives as collectives
    extend_scope(collectives)
    # ``reduce_op`` and ``group`` are not defined in this file; presumably they
    # become module globals through extend_scope(collectives) -- TODO confirm.
    # The first argument (False here, True in init_master_worker) presumably
    # selects the master-worker mode -- TODO confirm.
    # This call must not sit inside an ``assert``: assert statements are removed
    # under ``python -O``, which would silently skip the initialisation.
    if not torch._C._dist_init_extension(False, reduce_op, group):
        raise RuntimeError("distributed module initialization failed")
def init_master_worker(backend):
    """Initialise torch.distributed through the master-worker entry point.

    Calls ``torch._C._dist_init_master_worker(backend)``, marks the package as
    initialised, copies the public names of ``torch.distributed.collectives``
    and ``torch.distributed.remote_types`` into this module's scope, and
    finally initialises the C extension.

    Args:
        backend: Passed unchanged to ``torch._C._dist_init_master_worker``.

    Raises:
        RuntimeError: If torch.distributed has already been initialised, or if
            ``torch._C._dist_init_extension`` returns a false value.
    """
    global _initialized
    if _initialized:
        raise RuntimeError("trying to initialize torch.distributed twice!")
    torch._C._dist_init_master_worker(backend)
    _initialized = True
    # Imported only after the backend initialisation call above has returned.
    import torch.distributed.collectives as collectives
    import torch.distributed.remote_types as remote_types
    extend_scope(collectives)
    extend_scope(remote_types)
    # ``reduce_op`` and ``group`` are not defined in this file; presumably they
    # become module globals through the extend_scope() calls -- TODO confirm.
    # The first argument (True here, False in init_process_group) presumably
    # selects the master-worker mode -- TODO confirm.
    # This call must not sit inside an ``assert``: assert statements are removed
    # under ``python -O``, which would silently skip the initialisation.
    if not torch._C._dist_init_extension(True, reduce_op, group):
        raise RuntimeError("distributed module initialization failed")