blob: 3480ac0e7b4d632f6c1ec7d320b99b6c09ccd61b [file] [log] [blame]
import torch
import warnings
# Emitted once at import time to tell users that this package is unstable.
warnings.warn("""
================================================================================
WARNING
================================================================================
torch.distributed is a highly experimental package. The API will change without
notice and we can't guarantee full correctness and expected performance yet.
We'll announce it once it's ready.
""")

# Flipped to True by init_process_group / init_master_worker; used to reject a
# second initialization attempt.
_initialized = False

# At module level locals() is the module's own namespace dict, so writing to
# _scope publishes new names as attributes of this module (see extend_scope).
_scope = locals()
def extend_scope(module):
    """Copy every public attribute of ``module`` into this module's namespace.

    Attributes whose names start with an underscore are skipped. Existing
    entries in ``_scope`` with the same name are overwritten.
    """
    exported = {}
    for attr_name in dir(module):
        if attr_name.startswith('_'):
            continue
        exported[attr_name] = getattr(module, attr_name)
    # Apply in one step so a failing getattr above leaves _scope untouched.
    _scope.update(exported)
def init_process_group(backend, init_method='env://', **kwargs):
    """Initialize torch.distributed in process-group mode.

    Args:
        backend: Backend identifier, forwarded unchanged to
            ``torch._C._dist_init_process_group``.
        init_method: Initialization method string forwarded to the C
            extension. Defaults to ``'env://'``.
        **kwargs: Optional ``world_size`` (default ``-1``), ``group_name``
            (default ``''``) and ``rank`` (default ``-1``). Any other keyword
            is rejected.

    Raises:
        AssertionError: If unexpected keyword arguments are passed (only
            while assertions are enabled).
        RuntimeError: If the package was already initialized, or if the C
            extension reports that its initialization failed.
    """
    world_size = kwargs.pop('world_size', -1)
    group_name = kwargs.pop('group_name', '')
    rank = kwargs.pop('rank', -1)
    assert len(kwargs) == 0, "got unexpected keyword arguments: %s" % ",".join(kwargs.keys())

    global _initialized
    if _initialized:
        raise RuntimeError("trying to initialize torch.distributed twice!")
    torch._C._dist_init_process_group(backend, init_method, world_size,
                                      group_name, rank)
    _initialized = True

    # Imported lazily, only after the backend has been initialized.
    import torch.distributed.collectives as collectives
    extend_scope(collectives)
    # NOTE(review): reduce_op and group are not defined in this file; they are
    # presumably published into module scope by extend_scope(collectives).
    #
    # This call has required side effects, so it must not live inside an
    # ``assert``: under ``python -O`` the whole statement would be stripped
    # and the extension would silently never be initialized.
    if not torch._C._dist_init_extension(False, reduce_op, group):
        raise RuntimeError("distributed module initialization failed")
def init_master_worker(backend, init_method='env://', **kwargs):
    """Initialize torch.distributed in master-worker mode.

    Args:
        backend: Backend identifier, forwarded unchanged to
            ``torch._C._dist_init_master_worker``.
        init_method: Initialization method string forwarded to the C
            extension. Defaults to ``'env://'``.
        **kwargs: Optional ``world_size`` (default ``-1``), ``group_name``
            (default ``''``) and ``rank`` (default ``-1``). Any other keyword
            is rejected.

    Raises:
        AssertionError: If unexpected keyword arguments are passed (only
            while assertions are enabled).
        RuntimeError: If the package was already initialized, or if the C
            extension reports that its initialization failed.
    """
    world_size = kwargs.pop('world_size', -1)
    group_name = kwargs.pop('group_name', '')
    rank = kwargs.pop('rank', -1)
    assert len(kwargs) == 0, "got unexpected keyword arguments: %s" % ",".join(kwargs.keys())

    global _initialized
    if _initialized:
        raise RuntimeError("trying to initialize torch.distributed twice!")
    torch._C._dist_init_master_worker(backend, init_method, world_size,
                                      group_name, rank)
    _initialized = True

    # Imported lazily, only after the backend has been initialized.
    import torch.distributed.collectives as collectives
    import torch.distributed.remote_types as remote_types
    extend_scope(collectives)
    extend_scope(remote_types)
    # NOTE(review): reduce_op and group are not defined in this file; they are
    # presumably published into module scope by the extend_scope calls above.
    #
    # This call has required side effects, so it must not live inside an
    # ``assert``: under ``python -O`` the whole statement would be stripped
    # and the extension would silently never be initialized.
    if not torch._C._dist_init_extension(True, reduce_op, group):
        raise RuntimeError("distributed module initialization failed")