import collections
import warnings

import torch.cuda
from typing import Optional, Sequence, Union


__all__ = ['all_reduce', 'reduce', 'broadcast', 'all_gather', 'reduce_scatter']

SUM = 0  # ncclRedOp_t
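# Note: NCCL's ncclRedOp_t enum also defines ncclProd (1), ncclMax (2) and
# ncclMin (3); only SUM is exposed through this module.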


def is_available(tensors):
    if not hasattr(torch._C, '_nccl_all_reduce'):
        warnings.warn('PyTorch is not compiled with NCCL support')
        return False

    devices = set()
    for tensor in tensors:
        if tensor.is_sparse:
            return False
        if not tensor.is_contiguous():
            return False
        if not tensor.is_cuda:
            return False
        device = tensor.get_device()
        if device in devices:
            return False
        devices.add(device)

    return True
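

# Usage sketch for is_available (illustrative comment only; assumes a PyTorch
# build with NCCL and two visible CUDA devices):
#
#     import torch
#     from torch.cuda import nccl
#
#     ts = [torch.ones(4, device=f"cuda:{i}") for i in range(2)]
#     nccl.is_available(ts)                      # True: dense, contiguous,
#                                                # CUDA, distinct devices
#     nccl.is_available([ts[0], ts[0].clone()])  # False: same device twice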


def version():
    ver = torch._C._nccl_version()
    major = ver >> 32
    minor = (ver >> 16) & 65535
    patch = ver & 65535
    return (major, minor, patch)
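

# As the unpacking above implies, the C++ side packs the NCCL version as
# (major << 32) | (minor << 16) | patch. For example:
#
#     major, minor, patch = torch.cuda.nccl.version()
#     print(f"NCCL {major}.{minor}.{patch}")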


def unique_id():
    return torch._C._nccl_unique_id()


def init_rank(num_ranks, uid, rank):
    return torch._C._nccl_init_rank(num_ranks, uid, rank)
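

# Bootstrap sketch (illustrative comment only; torch.distributed wraps this
# flow, and direct use is rare). Rank 0 creates a unique id, shares it with
# the other processes out of band (e.g. over a pipe or TCP), and every rank
# then builds its communicator from it. `world_size` and `rank` below are
# placeholders for values supplied by the launcher:
#
#     uid = nccl.unique_id()                        # on rank 0 only
#     comm = nccl.init_rank(world_size, uid, rank)  # on every rank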


def _check_sequence_type(inputs: Union[torch.Tensor, Sequence[torch.Tensor]]) -> None:
    if not isinstance(inputs, collections.abc.Container) or isinstance(inputs, torch.Tensor):
        raise TypeError("Inputs should be a collection of tensors")


def all_reduce(inputs, outputs=None, op=SUM, streams=None, comms=None):
    _check_sequence_type(inputs)
    if outputs is None:
        outputs = inputs
    _check_sequence_type(outputs)
    torch._C._nccl_all_reduce(inputs, outputs, op, streams, comms)
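

# Usage sketch (illustrative comment only; assumes two visible CUDA devices).
# With `outputs` omitted the reduction is done in place, so every tensor ends
# up holding the elementwise sum:
#
#     import torch
#     from torch.cuda import nccl
#
#     ts = [torch.full((4,), float(i + 1), device=f"cuda:{i}") for i in range(2)]
#     nccl.all_reduce(ts)
#     # ts[0] and ts[1] both now hold [3., 3., 3., 3.]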


# `output` used to be `outputs`, taking in a list of tensors. So we have two
# arguments for BC reasons.
def reduce(inputs: Sequence[torch.Tensor],
           output: Optional[Union[torch.Tensor, Sequence[torch.Tensor]]] = None,
           root: int = 0,
           op: int = SUM,
           streams: Optional[Sequence[torch.cuda.Stream]] = None,
           comms=None, *,
           outputs: Optional[Sequence[torch.Tensor]] = None) -> None:
    _check_sequence_type(inputs)
    _output: torch.Tensor
    if outputs is not None:
        if output is not None:
            raise ValueError(
                "'output' and 'outputs' cannot both be specified. 'outputs' is deprecated in "
                "favor of 'output', taking in a single output tensor. The signature of reduce is: "
                "reduce(inputs, output=None, root=0, op=SUM, streams=None, comms=None).")
        else:
            warnings.warn(
                "nccl.reduce with an output tensor list is deprecated. "
                "Please specify a single output tensor with argument 'output' instead.")
            _output = outputs[root]
    elif not isinstance(output, torch.Tensor) and isinstance(output, collections.abc.Sequence):
        # User called the old API with a positional list of output tensors.
        warnings.warn(
            "nccl.reduce with an output tensor list is deprecated. "
            "Please specify a single output tensor.")
        _output = output[root]
    else:
        _output = inputs[root] if output is None else output
    torch._C._nccl_reduce(inputs, _output, root, op, streams, comms)
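

# Usage sketch (illustrative comment only; assumes two visible CUDA devices).
# Only the root receives the reduced result; with no `output` given, the
# reduction is done in place into inputs[root]:
#
#     ts = [torch.full((2,), float(i + 1), device=f"cuda:{i}") for i in range(2)]
#     nccl.reduce(ts, root=0)
#     # ts[0] now holds [3., 3.]; ts[1] is left unchanged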


def broadcast(inputs: Sequence[torch.Tensor], root: int = 0, streams=None, comms=None) -> None:
    _check_sequence_type(inputs)
    torch._C._nccl_broadcast(inputs, root, streams, comms)
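

# Usage sketch (illustrative comment only; assumes two visible CUDA devices).
# broadcast copies inputs[root] into every other tensor in the list, in place:
#
#     ts = [torch.full((2,), float(i), device=f"cuda:{i}") for i in range(2)]
#     nccl.broadcast(ts, root=0)
#     # ts[1] now holds [0., 0.], matching ts[0]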


def all_gather(inputs: Sequence[torch.Tensor], outputs: Sequence[torch.Tensor], streams=None, comms=None) -> None:
    _check_sequence_type(inputs)
    _check_sequence_type(outputs)
    torch._C._nccl_all_gather(inputs, outputs, streams, comms)
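

# Usage sketch (illustrative comment only; assumes two visible CUDA devices).
# Every output receives the concatenation of all inputs, so each output must
# hold len(inputs) * input.numel() elements:
#
#     ins = [torch.full((2,), float(i), device=f"cuda:{i}") for i in range(2)]
#     outs = [torch.empty(4, device=f"cuda:{i}") for i in range(2)]
#     nccl.all_gather(ins, outs)
#     # each outs[i] now holds [0., 0., 1., 1.]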


def reduce_scatter(inputs: Sequence[torch.Tensor],
                   outputs: Sequence[torch.Tensor],
                   op: int = SUM,
                   streams=None, comms=None) -> None:
    _check_sequence_type(inputs)
    _check_sequence_type(outputs)
    torch._C._nccl_reduce_scatter(inputs, outputs, op, streams, comms)
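

# Usage sketch (illustrative comment only; assumes two visible CUDA devices).
# The elementwise reduction of the inputs is split across devices, so each
# output holds input.numel() / len(inputs) elements of the summed result:
#
#     ins = [torch.full((4,), float(i + 1), device=f"cuda:{i}") for i in range(2)]
#     outs = [torch.empty(2, device=f"cuda:{i}") for i in range(2)]
#     nccl.reduce_scatter(ins, outs)
#     # outs[0] holds [3., 3.] (first half); outs[1] holds [3., 3.] (second half)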