import os
import ctypes
import warnings
import torch.cuda
from torch.backends.cudnn import int_array
lib = None
__all__ = ['all_reduce', 'reduce', 'broadcast', 'all_gather', 'reduce_scatter']
def _libnccl():
    # Look for NCCL symbols among the libraries already loaded into the
    # process (dlopen(NULL)); if ncclCommDestroy is not found, NCCL is
    # treated as unavailable.
    global lib
    if lib is None:
        lib = ctypes.cdll.LoadLibrary(None)
        if hasattr(lib, 'ncclCommDestroy'):
            lib.ncclCommDestroy.restype = None
        else:
            lib = None
    return lib
def is_available(tensors):
    # Collectives require contiguous CUDA tensors that live on distinct
    # devices, plus a loadable NCCL library.
    devices = set()
    for tensor in tensors:
        if not tensor.is_contiguous():
            return False
        if not tensor.is_cuda:
            return False
        device = tensor.get_device()
        if device in devices:
            return False
        devices.add(device)

    if _libnccl() is None:
        warnings.warn('NCCL library not found. Check your LD_LIBRARY_PATH')
        return False

    return True
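# Usage sketch (an assumption for illustration, not part of this module): if
# this file is importable as torch.cuda.nccl and at least two CUDA devices are
# visible, availability can be checked before issuing a collective. Tensor
# names below are hypothetical.
#
#     import torch
#     from torch.cuda import nccl
#
#     xs = [torch.randn(256).cuda(i) for i in range(2)]
#     if nccl.is_available(xs):
#         nccl.all_reduce(xs)   # see all_reduce below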
_communicators = {}
# ncclDataType_t
ncclChar = 0
ncclInt = 1
ncclHalf = 2
ncclFloat = 3
ncclDouble = 4
ncclInt64 = 5
ncclUint64 = 6
# ncclRedOp_t
SUM = 0
PROD = 1
MAX = 2
MIN = 3
status_codes = {
    0: "Success",
    1: "Unhandled Cuda Error",
    2: "System Error",
    3: "Internal Error",
    4: "Invalid Device Pointer",
    5: "Invalid Rank",
    6: "Unsupported Device Count",
    7: "Device Not Found",
    8: "Invalid Device Index",
    9: "Lib Wrapper Not Set",
    10: "Cuda Malloc Failed",
    11: "Rank Mismatch",
    12: "Invalid Argument",
    13: "Invalid Type",
    14: "Invalid Operation",
}

nccl_types = {
    'torch.cuda.ByteTensor': ncclChar,
    'torch.cuda.CharTensor': ncclChar,
    'torch.cuda.IntTensor': ncclInt,
    'torch.cuda.HalfTensor': ncclHalf,
    'torch.cuda.FloatTensor': ncclFloat,
    'torch.cuda.DoubleTensor': ncclDouble,
    'torch.cuda.LongTensor': ncclInt64,
}
class NcclError(RuntimeError):
    def __init__(self, status):
        self.status = status
        msg = '{0} ({1})'.format(status_codes.get(status), status)
        super(NcclError, self).__init__(msg)


class NcclComm(ctypes.c_void_p):
    # Opaque handle corresponding to the ncclComm_t pointer type.
    pass
class NcclCommList(object):
    # Holds one NCCL communicator per device; _as_parameter_ lets the whole
    # array of handles be passed directly to the NCCL C API.
    def __init__(self, devices):
        self.devices = devices
        ptrs = (NcclComm * len(devices))()
        self._as_parameter_ = ptrs
        check_error(lib.ncclCommInitAll(self, len(devices), int_array(devices)))

    def __getitem__(self, i):
        return self._as_parameter_[i]

    def __del__(self):
        for i in range(len(self.devices)):
            lib.ncclCommDestroy(self[i])
def check_error(status):
    if status != 0:
        raise NcclError(status)
def communicator(inputs, outputs=None):
    # Return a cached NcclCommList for this set of devices, creating and
    # initializing it on first use.
    if _libnccl() is None:
        raise RuntimeError('Unable to load NCCL library')
    devices = [input.get_device() for input in inputs]
    if outputs is not None:
        for device, output in zip(devices, outputs):
            if output.get_device() != device:
                raise ValueError("inputs and outputs must be on the same devices")
    key = ','.join(str(d) for d in devices)
    if key not in _communicators:
        _communicators[key] = NcclCommList(devices)
    return _communicators[key]
def cudaStream():
    # TODO: return the current stream
    # ffi.C.THCState_getCurrentStream(cutorch.getState())
    return None
def all_reduce(inputs, outputs=None, op=SUM):
    if outputs is None:
        outputs = inputs
    _check_inputs(inputs, outputs)
    comm = communicator(inputs, outputs)
    count = inputs[0].numel()
    data_type = nccl_types[inputs[0].type()]
    with torch.cuda._free_mutex():
        for i in range(len(inputs)):
            with torch.cuda.device(comm.devices[i]):
                check_error(lib.ncclAllReduce(
                    ctypes.c_void_p(inputs[i].data_ptr()),
                    ctypes.c_void_p(outputs[i].data_ptr()),
                    count, data_type, op, comm[i], cudaStream()))
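# Hedged example (same assumptions as the sketch after is_available): with
# outputs omitted, the reduction happens in place and every tensor ends up
# holding the elementwise result.
#
#     xs = [torch.randn(256).cuda(i) for i in range(2)]
#     nccl.all_reduce(xs)               # each xs[i] now holds the elementwise sum
#     nccl.all_reduce(xs, op=nccl.MAX)  # elementwise maximum instead of SUM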
def reduce(inputs, outputs=None, root=0, op=SUM):
    assert(root >= 0 and root < len(inputs))
    if outputs is None:
        outputs = inputs
    _check_inputs(inputs, outputs)
    comm = communicator(inputs)
    count = inputs[0].numel()
    data_type = nccl_types[inputs[0].type()]
    with torch.cuda._free_mutex():
        for i in range(len(inputs)):
            with torch.cuda.device(comm.devices[i]):
                check_error(lib.ncclReduce(
                    ctypes.c_void_p(inputs[i].data_ptr()),
                    ctypes.c_void_p(outputs[i].data_ptr()), count,
                    data_type, op, root, comm[i], cudaStream()))
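# Hedged example (same assumptions as above): ncclReduce only writes the
# result buffer on the root device, so with outputs omitted only inputs[root]
# should be relied upon afterwards.
#
#     xs = [torch.randn(256).cuda(i) for i in range(2)]
#     nccl.reduce(xs, root=0)   # xs[0] holds the elementwise sum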
def broadcast(inputs, root=0):
    assert(root >= 0 and root < len(inputs))
    _check_inputs(inputs, inputs)
    comm = communicator(inputs)
    count = inputs[0].numel()
    data_type = nccl_types[inputs[0].type()]
    with torch.cuda._free_mutex():
        for i in range(len(inputs)):
            with torch.cuda.device(comm.devices[i]):
                check_error(lib.ncclBcast(
                    ctypes.c_void_p(inputs[i].data_ptr()), count,
                    data_type, root, comm[i], cudaStream()))
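# Hedged example (same assumptions as above): the root tensor is copied into
# every other tensor, which only needs a matching size and type.
#
#     xs = [torch.randn(256).cuda(0), torch.zeros(256).cuda(1)]
#     nccl.broadcast(xs, root=0)   # xs[1] is now a copy of xs[0]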
def all_gather(inputs, outputs):
    _check_inputs(inputs, outputs, len(inputs))
    comm = communicator(inputs, outputs)
    count = inputs[0].numel()
    data_type = nccl_types[inputs[0].type()]
    with torch.cuda._free_mutex():
        for i in range(len(inputs)):
            with torch.cuda.device(comm.devices[i]):
                check_error(lib.ncclAllGather(
                    ctypes.c_void_p(inputs[i].data_ptr()), count, data_type,
                    ctypes.c_void_p(outputs[i].data_ptr()), comm[i],
                    cudaStream()))
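# Hedged example (same assumptions as above): every input is gathered into
# every output, so each output must be len(inputs) times the size of an input.
#
#     ins = [torch.randn(128).cuda(i) for i in range(2)]
#     outs = [torch.zeros(256).cuda(i) for i in range(2)]
#     nccl.all_gather(ins, outs)   # outs[i][:128] == ins[0], outs[i][128:] == ins[1]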
def reduce_scatter(inputs, outputs, op=SUM):
    _check_inputs(inputs, outputs, 1.0 / len(inputs))
    comm = communicator(inputs, outputs)
    count = inputs[0].numel() // len(inputs)
    data_type = nccl_types[inputs[0].type()]
    with torch.cuda._free_mutex():
        for i in range(len(inputs)):
            with torch.cuda.device(comm.devices[i]):
                check_error(lib.ncclReduceScatter(
                    ctypes.c_void_p(inputs[i].data_ptr()),
                    ctypes.c_void_p(outputs[i].data_ptr()), count, data_type,
                    op, comm[i], cudaStream()))
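# Hedged example (same assumptions as above): the elementwise reduction is
# split into len(inputs) equal chunks, so each output is 1/len(inputs) the
# size of an input.
#
#     ins = [torch.randn(256).cuda(i) for i in range(2)]
#     outs = [torch.zeros(128).cuda(i) for i in range(2)]
#     nccl.reduce_scatter(ins, outs)   # outs[i] holds the i-th 128-element chunk of the sum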
def _check_inputs(inputs, outputs=None, size_multiplier=1):
    # Validate device placement, contiguity, and sizes before calling into
    # NCCL. If outputs is omitted, the collective is assumed to be in place.
    if outputs is None:
        outputs = inputs
    devices = set()
    size = inputs[0].numel()
    if len(inputs) != len(outputs):
        raise ValueError('inputs and outputs must be the same length')
    for input, output in zip(inputs, outputs):
        if not input.is_cuda:
            raise TypeError('inputs must be CUDA inputs')
        if not input.is_contiguous():
            raise ValueError('inputs must be contiguous')

        device = input.get_device()
        if device in devices:
            raise ValueError('inputs must be on unique devices')
        devices.add(device)
        if input.numel() != size:
            raise ValueError('inputs must be the same size')

        if not output.is_contiguous():
            raise ValueError('outputs must be contiguous')
        if output.get_device() != device:
            raise ValueError('inputs and outputs must be on the same devices')
        if output.numel() != size * size_multiplier:
            raise ValueError(('incorrect output size; expected {0} but got {1}'
                              .format(size * size_multiplier, output.numel())))