| import threading |
| import torch |
| from torch.autograd import Variable |
| |
| |
| def get_a_var(obj): |
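    """Return the first Variable found in ``obj``, searching lists, tuples,
    and dicts recursively; return ``None`` if no Variable is present."""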
| if isinstance(obj, Variable): |
| return obj |
| |
    if isinstance(obj, (list, tuple)):
| results = map(get_a_var, obj) |
| for result in results: |
| if isinstance(result, Variable): |
| return result |
    if isinstance(obj, dict):
        # Dict entries are searched as (key, value) pairs.
        results = map(get_a_var, obj.items())
| for result in results: |
| if isinstance(result, Variable): |
| return result |
| return None |
| |
| |
| def parallel_apply(modules, inputs, kwargs_tup=None, devices=None): |
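    """Apply each module in ``modules`` to the corresponding entry of
    ``inputs``, using one thread per module when more than one is given.

    ``kwargs_tup`` optionally supplies per-module keyword arguments and
    ``devices`` optionally supplies per-module CUDA devices; when a device
    is not given, it is inferred from the module's input. Outputs are
    returned in the same order as ``inputs``.
    """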
| assert len(modules) == len(inputs) |
| if kwargs_tup is not None: |
| assert len(modules) == len(kwargs_tup) |
| else: |
| kwargs_tup = ({},) * len(modules) |
| if devices is not None: |
| assert len(modules) == len(devices) |
| else: |
| devices = [None] * len(modules) |
| |
    lock = threading.Lock()
    results = {}
    # Capture the caller's grad mode; each worker thread restores it explicitly.
    grad_enabled = torch.is_grad_enabled()
| |
    def _worker(i, module, input, kwargs, device=None):
        # Grad mode is thread-local, so set it again in this thread.
        torch.set_grad_enabled(grad_enabled)
        if device is None:
            # Infer the target device from the first Variable in the input.
            device = get_a_var(input).get_device()
        try:
            with torch.cuda.device(device):
                output = module(*input, **kwargs)
            with lock:
                results[i] = output
        except Exception as e:
            # Record the exception so the main thread can re-raise it.
            with lock:
                results[i] = e
| |
    if len(modules) > 1:
        # One thread per module; a single module is applied inline below.
| threads = [threading.Thread(target=_worker, |
| args=(i, module, input, kwargs, device)) |
| for i, (module, input, kwargs, device) in |
| enumerate(zip(modules, inputs, kwargs_tup, devices))] |
| |
| for thread in threads: |
| thread.start() |
| for thread in threads: |
| thread.join() |
| else: |
| _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0]) |
| |
    outputs = []
    for i in range(len(inputs)):
        output = results[i]
        if isinstance(output, Exception):
            # Re-raise, in the calling thread, any exception a worker recorded.
            raise output
| outputs.append(output) |
| return outputs |
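

# The block below is a minimal usage sketch, not part of this module's API.
# It assumes at least two CUDA devices are available and uses ``replicate``
# from ``torch.nn.parallel`` to obtain per-device copies of a hypothetical
# ``nn.Linear`` module; each input is wrapped in a one-element tuple because
# the worker unpacks it with ``*input``.
if __name__ == "__main__":
    import torch.nn as nn
    from torch.nn.parallel import replicate

    if torch.cuda.device_count() >= 2:
        module = nn.Linear(10, 5).cuda(0)
        replicas = replicate(module, [0, 1])
        inputs = [(Variable(torch.randn(4, 10).cuda(0)),),
                  (Variable(torch.randn(4, 10).cuda(1)),)]
        outputs = parallel_apply(replicas, inputs)
        # Expect two (4, 5) outputs, one per device.
        print([tuple(out.size()) for out in outputs])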