import torch
from ..modules import Module
from .scatter_gather import scatter_kwargs, gather
from .replicate import replicate
from .parallel_apply import parallel_apply


class DataParallel(Module):

    """Implements data parallelism at the module level.

    This container parallelizes the application of the given module by
    splitting the input across the specified devices, chunking it along the
    batch dimension. In the forward pass, the module is replicated on each
    device, and each replica handles a portion of the input. During the
    backward pass, gradients from each replica are summed into the original
    module.

    The batch size should be larger than the number of GPUs used. It should
    also be an integer multiple of the number of GPUs, so that each chunk is
    the same size and each GPU processes the same number of samples (e.g., a
    batch of 64 split across 4 GPUs gives 4 chunks of 16 samples each).

    See also: :ref:`cuda-nn-dataparallel-instead`

    Arbitrary positional and keyword inputs are allowed to be passed into
    DataParallel EXCEPT Tensors. All variables will be scattered on the dim
    specified (default 0). Primitive types will be broadcast, but all other
    types will be shallow-copied and can be corrupted if written to in the
    model's forward pass.

    Args:
        module: module to be parallelized
        device_ids: CUDA devices (default: all devices)
        output_device: device location of output (default: device_ids[0])

    Example::

        >>> net = torch.nn.DataParallel(model, device_ids=[0, 1, 2])
        >>> output = net(input_var)
    """

    # TODO: update notes/cuda.rst when this class handles 8+ GPUs well

    def __init__(self, module, device_ids=None, output_device=None, dim=0):
        super(DataParallel, self).__init__()
        if device_ids is None:
            device_ids = list(range(torch.cuda.device_count()))
        if output_device is None:
            output_device = device_ids[0]
        self.dim = dim
        self.module = module
        self.device_ids = device_ids
        self.output_device = output_device
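        # With only one device, move the module onto it up front; forward()
        # will then call the module directly without replicating it.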
        if len(self.device_ids) == 1:
            self.module.cuda(device_ids[0])

    def forward(self, *inputs, **kwargs):
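        # Split the inputs and kwargs into one chunk per device along self.dim.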
        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        if len(self.device_ids) == 1:
            return self.module(*inputs[0], **kwargs[0])
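        # Replicate the module onto the devices that received an input chunk,
        # run each replica on its chunk in parallel, then gather the outputs
        # onto the output device.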
        replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
        outputs = self.parallel_apply(replicas, inputs, kwargs)
        return self.gather(outputs, self.output_device)

    def replicate(self, module, device_ids):
        return replicate(module, device_ids)

    def scatter(self, inputs, kwargs, device_ids):
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def parallel_apply(self, replicas, inputs, kwargs):
        return parallel_apply(replicas, inputs, kwargs)

    def gather(self, outputs, output_device):
        return gather(outputs, output_device, dim=self.dim)


def data_parallel(module, inputs, device_ids=None, output_device=None, dim=0, module_kwargs=None):
    """Evaluates module(input) in parallel across the GPUs given in device_ids.

    This is the functional version of the DataParallel module.

    Args:
        module: the module to evaluate in parallel
        inputs: inputs to the module
        device_ids: GPU ids on which to replicate module
        output_device: GPU location of the output. Use -1 to indicate the CPU
            (default: device_ids[0])
        dim: dimension along which to scatter the inputs (default: 0)
        module_kwargs: keyword arguments to be scattered alongside the inputs
            and passed to the module (default: None)

    Returns:
        a Variable containing the result of module(input) located on
        output_device
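
    Example::

        >>> # illustrative sketch: `model` and `input_var` are placeholder
        >>> # names for a module and a batched input Variable
        >>> output = data_parallel(model, input_var, device_ids=[0, 1])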
    """
    if not isinstance(inputs, tuple):
        inputs = (inputs,)

    if device_ids is None:
        device_ids = list(range(torch.cuda.device_count()))

    if output_device is None:
        output_device = device_ids[0]

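    # Scatter the inputs (and any module_kwargs) across the devices along dim;
    # with a single device the module is simply called on its chunk.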
    inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids, dim)
    if len(device_ids) == 1:
        return module(*inputs[0], **module_kwargs[0])
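    # Replicate the module onto the devices that received an input chunk, run
    # the replicas in parallel, and gather the results onto output_device.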
    replicas = replicate(module, device_ids[:len(inputs)])
    outputs = parallel_apply(replicas, inputs, module_kwargs)
    return gather(outputs, output_device, dim)