import threading
import torch
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
from ..modules import Module
from torch.cuda._utils import _get_device_index
from torch.cuda.amp import autocast
from torch._utils import ExceptionWrapper

__all__ = ['get_a_var', 'parallel_apply']

def get_a_var(obj: Union[torch.Tensor, List[Any], Tuple[Any, ...], Dict[Any, Any]]) -> Optional[torch.Tensor]:
    r"""Return the first tensor found in ``obj``, recursing into lists, tuples and
    dict items; return ``None`` if no tensor is present."""
    if isinstance(obj, torch.Tensor):
        return obj

    if isinstance(obj, (list, tuple)):
        for result in map(get_a_var, obj):
            if isinstance(result, torch.Tensor):
                return result
    if isinstance(obj, dict):
        for result in map(get_a_var, obj.items()):
            if isinstance(result, torch.Tensor):
                return result
    return None
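
# A quick illustration of ``get_a_var`` on nested inputs (comments only, not executed):
#
#     get_a_var((torch.zeros(1), "meta"))     # -> tensor([0.])
#     get_a_var({"x": [3, torch.ones(2)]})    # -> tensor([1., 1.])
#     get_a_var([1, 2, 3])                    # -> None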

def parallel_apply(
    modules: Sequence[Module],
    inputs: Sequence[Any],
    kwargs_tup: Optional[Sequence[Dict[str, Any]]] = None,
    devices: Optional[Sequence[Optional[Union[int, torch.device]]]] = None,
) -> List[Any]:
    r"""Applies each `module` in :attr:`modules` in parallel on the arguments
    contained in :attr:`inputs` (positional) and :attr:`kwargs_tup` (keyword),
    on each of :attr:`devices`.

    Args:
        modules (Sequence[Module]): modules to be parallelized
        inputs (Sequence[Any]): inputs to the modules, one entry per module
        kwargs_tup (Sequence[dict], optional): keyword arguments for each module
        devices (list of int or torch.device, optional): CUDA device for each module

    :attr:`modules`, :attr:`inputs`, :attr:`kwargs_tup` (if given), and
    :attr:`devices` (if given) should all have the same length. Moreover, each
    element of :attr:`inputs` can either be a single object passed as the only
    argument to a module, or a collection of positional arguments.
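
    A minimal usage sketch (illustrative only; the replica and input names are
    hypothetical and two CUDA devices are assumed):

    Example::

        >>> # xdoctest: +SKIP
        >>> replicas = [torch.nn.Linear(4, 4).to("cuda:0"), torch.nn.Linear(4, 4).to("cuda:1")]
        >>> chunks = [torch.randn(2, 4, device="cuda:0"), torch.randn(2, 4, device="cuda:1")]
        >>> out0, out1 = parallel_apply(replicas, chunks, devices=["cuda:0", "cuda:1"])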
    """
    assert len(modules) == len(inputs), \
        f'The number of modules {len(modules)} is not equal to the number of inputs {len(inputs)}'
    if kwargs_tup is not None:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = (cast(Dict[str, Any], {}),) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)
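    # Normalize every device entry to a CUDA device index and capture the stream that is
    # currently active on each of those devices, so the workers launch onto the same
    # streams the caller was using.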
    devices = [_get_device_index(x, True) for x in devices]
    streams = [torch.cuda.current_stream(x) for x in devices]
    lock = threading.Lock()
    results = {}
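    # Grad mode and autocast state are thread-local, so snapshot them here and re-apply
    # them inside each worker thread.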
    grad_enabled, autocast_enabled = torch.is_grad_enabled(), torch.is_autocast_enabled()

    def _worker(
        i: int,
        module: Module,
        input: Any,
        kwargs: Dict[str, Any],
        device: Optional[Union[int, torch.device]] = None,
        stream: Optional[torch.cuda.Stream] = None,
    ) -> None:
        torch.set_grad_enabled(grad_enabled)
        if device is None:
            t = get_a_var(input)
            if t is None:
                with lock:
                    results[i] = ExceptionWrapper(
                        where=f"in replica {i}, no device was provided and no tensor input was found; "
                              "device cannot be resolved")
                return
            device = t.get_device()
        if stream is None:
            stream = torch.cuda.current_stream(device)
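        # Run the replica under its target device, the captured stream, and the caller's
        # autocast state; any exception is stored and re-raised later in the calling thread.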
        try:
            with torch.cuda.device(device), torch.cuda.stream(stream), autocast(enabled=autocast_enabled):
                # this also avoids accidental slicing of `input` if it is a Tensor
                if not isinstance(input, (list, tuple)):
                    input = (input,)
                output = module(*input, **kwargs)
            with lock:
                results[i] = output
        except Exception:
            with lock:
                results[i] = ExceptionWrapper(
                    where=f"in replica {i} on device {device}")

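    # One Python thread per replica: CUDA work is launched asynchronously, so plain
    # threads are enough here; a single replica is simply run inline to skip the
    # threading overhead.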
    if len(modules) > 1:
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, kwargs, device, stream))
                   for i, (module, input, kwargs, device, stream) in
                   enumerate(zip(modules, inputs, kwargs_tup, devices, streams))]

        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0], streams[0])

    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, ExceptionWrapper):
            output.reraise()
        outputs.append(output)
    return outputs
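

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustration only, not part of this module's API): it
    # assumes at least two CUDA devices and drives ``parallel_apply`` the way
    # ``DataParallel`` typically does, via ``replicate`` and ``scatter``.
    if torch.cuda.device_count() >= 2:
        from torch.nn.parallel import replicate, scatter

        base = torch.nn.Linear(4, 4).cuda(0)
        replicas = replicate(base, [0, 1])            # one module copy per device
        chunks = scatter(torch.randn(8, 4), [0, 1])   # split the batch across devices
        outputs = parallel_apply(replicas, chunks, devices=[0, 1])
        print([tuple(o.shape) for o in outputs], [str(o.device) for o in outputs])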