# blob: 43a868025858436b573f5f0872f9e8eac10fb5df [file] [log] [blame]
import threading
import torch
from torch.autograd import Variable
def get_a_var(obj):
    """Return the first ``Variable`` found inside ``obj``, or ``None``.

    ``obj`` may itself be a ``Variable``, or an arbitrarily nested
    combination of lists, tuples and dicts.  Containers are searched
    depth-first in iteration order; for a dict the ``(key, value)`` pairs
    are searched, so keys are inspected as well as values.  Any other kind
    of object yields ``None``.
    """
    if isinstance(obj, Variable):
        return obj

    # Pick what to descend into; anything that is not a known container
    # cannot hold a Variable as far as this helper is concerned.
    if isinstance(obj, (list, tuple)):
        children = obj
    elif isinstance(obj, dict):
        children = obj.items()
    else:
        return None

    for child in children:
        found = get_a_var(child)
        if isinstance(found, Variable):
            return found
    return None
def parallel_apply(modules, inputs, kwargs_tup=None, devices=None):
    """Call each module on its matching input, using one thread per module.

    Args:
        modules: sequence of callables to run.
        inputs: sequence of the same length as ``modules``.  Each entry holds
            the positional arguments for the matching module: a list or
            tuple is unpacked, any other object is passed as the single
            positional argument.
        kwargs_tup: optional sequence with one keyword-argument dict per
            module.  Defaults to an empty dict for every module.
        devices: optional sequence with one CUDA device per module.  When
            omitted (or when an entry is ``None``) the device is taken from
            the first ``Variable`` found in the corresponding input.

    Returns:
        A list with the output of each module, in the order of ``modules``.
        An empty ``modules`` sequence yields an empty list.

    Raises:
        AssertionError: if the argument sequences differ in length.
        Exception: the exception raised by the lowest-indexed failing
            worker is re-raised in the calling thread.
    """
    assert len(modules) == len(inputs)
    if kwargs_tup is not None:
        assert len(modules) == len(kwargs_tup)
    else:
        kwargs_tup = ({},) * len(modules)
    if devices is not None:
        assert len(modules) == len(devices)
    else:
        devices = [None] * len(modules)

    # Nothing to run; avoids indexing modules[0] below.
    if len(modules) == 0:
        return []

    lock = threading.Lock()
    results = {}
    # Captured here so each worker thread can be given the caller's grad
    # mode explicitly.
    grad_enabled = torch.is_grad_enabled()

    def _worker(i, module, input, kwargs, device=None):
        torch.set_grad_enabled(grad_enabled)
        try:
            # Device resolution lives inside the try block so that a failure
            # is recorded for this worker instead of leaving results[i]
            # unset (which would surface later as an unrelated KeyError).
            if device is None:
                var = get_a_var(input)
                if var is None:
                    raise RuntimeError(
                        "parallel_apply: no device was provided for module "
                        "{} and no Variable was found in its input".format(i))
                device = var.get_device()
            # Wrap a lone argument so that star-unpacking below does not
            # iterate over (and thereby slice) the object itself.
            if not isinstance(input, (list, tuple)):
                input = (input,)
            with torch.cuda.device(device):
                output = module(*input, **kwargs)
        except Exception as e:
            output = e
        with lock:
            results[i] = output

    if len(modules) > 1:
        threads = [threading.Thread(target=_worker,
                                    args=(i, module, input, kwargs, device))
                   for i, (module, input, kwargs, device) in
                   enumerate(zip(modules, inputs, kwargs_tup, devices))]
        for thread in threads:
            thread.start()
        for thread in threads:
            thread.join()
    else:
        # A single module is run inline; no thread is needed.
        _worker(0, modules[0], inputs[0], kwargs_tup[0], devices[0])

    outputs = []
    for i in range(len(modules)):
        output = results[i]
        # Workers store exceptions instead of raising; surface the first one.
        if isinstance(output, Exception):
            raise output
        outputs.append(output)
    return outputs