import torch.cuda
import torch.cuda.comm as comm
from torch.autograd import Function


class Broadcast(Function):
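    """Broadcast a CUDA tensor to a set of GPUs (old-style autograd Function).

    forward copies the input onto every device in target_gpus and returns one
    tensor per device; backward reduce-adds the incoming gradients onto the
    device the input originally came from.
    """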
    def __init__(self, target_gpus):
        super(Broadcast, self).__init__()
        self.target_gpus = target_gpus

    def forward(self, input):
        assert input.is_cuda, "Broadcast function not implemented for CPU tensors"
        self.input_device = input.get_device()
        return comm.broadcast(input, self.target_gpus)

    def backward(self, *grad_output):
        return comm.reduce_add(grad_output, self.input_device)


class Gather(Function):
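    """Gather tensors from multiple GPUs onto a single target device.

    forward concatenates the inputs along dim on target_gpu; backward splits
    the gradient along dim into chunks matching each input's size and scatters
    them back to the devices the inputs came from.
    """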
    def __init__(self, target_gpu, dim=0):
        super(Gather, self).__init__()
        self.target_gpu = target_gpu
        self.dim = dim

    def forward(self, *inputs):
        assert all(i.is_cuda for i in inputs), \
            "Gather function not implemented for CPU tensors"
        self.input_gpus = tuple(i.get_device() for i in inputs)
        self.input_sizes = tuple(i.size(self.dim) for i in inputs)
        return comm.gather(inputs, self.dim, self.target_gpu)

    def backward(self, grad_output):
        return comm.scatter(grad_output, self.input_gpus, self.input_sizes,
                            self.dim)


class Scatter(Function):
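    """Scatter a tensor across multiple GPUs along a given dimension.

    forward slices the input along dim (optionally using explicit chunk_sizes)
    and places one chunk on each device in target_gpus; backward gathers the
    gradients back onto the source device (or onto the CPU when the input was
    a CPU tensor).
    """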
    def __init__(self, target_gpus, chunk_sizes=None, dim=0):
        super(Scatter, self).__init__()
        self.target_gpus = target_gpus
        self.chunk_sizes = chunk_sizes
        self.dim = dim

    def forward(self, input):
        self.input_device = input.get_device() if input.is_cuda else -1
        return comm.scatter(input, self.target_gpus, self.chunk_sizes, self.dim)

    # forward returns one chunk per device, so backward receives one gradient
    # per output; collect them with *grad_output so comm.gather gets a sequence.
    def backward(self, *grad_output):
        return comm.gather(grad_output, self.dim, self.input_device)
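

# A minimal usage sketch, not part of the module: these are old-style autograd
# Functions, so each is instantiated with its configuration and then applied to
# Variables. The example below is illustrative only and assumes at least two
# visible CUDA devices and the pre-0.4 Variable API.
#
#   from torch.autograd import Variable
#
#   x = Variable(torch.randn(8, 4).cuda(0), requires_grad=True)
#   chunks = Scatter([0, 1], dim=0)(x)   # tuple of 2 Variables, 4 rows on each GPU
#   y = Gather(0, dim=0)(*chunks)        # concatenated back into one Variable on GPU 0
#   y.sum().backward()                   # grads flow back through Gather, then Scatter
#
#   copies = Broadcast([0, 1])(x)        # tuple with one copy of x per listed device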