blob: 622f8975e04d6e4e9e0608f09c6c459deb01bb5b [file] [log] [blame]
from copy import copy
from collections import OrderedDict
from ..modules.container import Container
import torch.cuda.comm as comm
def _replicate_module(module, gpu, param_remap):
    """Build a shallow replica of ``module`` wired to remapped tensors.

    The replica is a ``copy`` of ``module`` whose ``_parameters`` and
    ``_buffers`` tables are rebuilt by looking every original tensor up in
    ``param_remap``.  When the replica is a ``Container`` its ``_modules``
    table is rebuilt as well, replicating each child recursively.

    Args:
        module: Module to replicate, or ``None``.
        gpu: Not read here; only forwarded to the recursive calls.
        param_remap: Mapping from an original parameter/buffer object to
            its replacement.  Lookups use ``dict.get``, so a tensor that is
            absent from the mapping (or a ``None`` entry) becomes ``None``
            in the replica.

    Returns:
        The replica, or ``None`` when ``module`` is ``None``.
    """
    if module is None:
        return None

    lookup = param_remap.get
    clone = copy(module)

    # copy() is shallow, so the tables must be replaced (not mutated) to
    # avoid touching the ones still owned by the source module.
    clone._parameters = OrderedDict(
        (name, lookup(tensor)) for name, tensor in module._parameters.items()
    )
    clone._buffers = {
        name: lookup(tensor) for name, tensor in module._buffers.items()
    }

    if isinstance(clone, Container):
        clone._modules = OrderedDict(
            (name, _replicate_module(submodule, gpu, param_remap))
            for name, submodule in module._modules.items()
        )
    return clone
def replicate(module, device_ids):
    """Create one replica of ``module`` for every entry of ``device_ids``.

    Each distinct parameter yielded by ``module.parameters()`` is broadcast
    once through ``Broadcast(device_ids)``, and each distinct non-``None``
    buffer found in ``module.modules()`` is broadcast once through
    ``comm.broadcast``.  The per-device copies are collected in one remap
    table per device, and ``_replicate_module`` then assembles a replica
    from each table.

    Args:
        module: Root module to replicate; must provide ``parameters()`` and
            ``modules()``, and its sub-modules a ``_buffers`` mapping.
        device_ids: Sequence of device ids.  It is iterated several times,
            so it must not be a one-shot iterator.

    Returns:
        A list of replicas, one per device id, in ``device_ids`` order.
    """
    # Function-scope import kept from the original (presumably to avoid a
    # circular import at module load time -- not verifiable from here).
    from ._functions import Broadcast

    # One {original tensor -> copy on that device} table per device.
    param_remap = [{} for _ in device_ids]

    seen_params = set()
    for param in module.parameters():
        if param in seen_params:
            continue
        seen_params.add(param)
        param_copies = Broadcast(device_ids)(param)
        for param_copy, remap in zip(param_copies, param_remap):
            remap[param] = param_copy

    # Buffers are de-duplicated by identity so a buffer object shared by
    # several modules is transferred only once.  The module tree keeps every
    # buffer alive for the duration of the loop, so ids cannot be recycled.
    seen_buffer_ids = set()
    for submodule in module.modules():
        for buf in submodule._buffers.values():
            # A None entry has nothing to broadcast; _replicate_module maps
            # it to None through dict.get, so it is safe to skip here.
            if buf is None or id(buf) in seen_buffer_ids:
                continue
            seen_buffer_ids.add(id(buf))
            buf_copies = comm.broadcast(buf, device_ids)
            for buf_copy, remap in zip(buf_copies, param_remap):
                remap[buf] = buf_copy

    return [_replicate_module(module, device_id, remap)
            for device_id, remap in zip(device_ids, param_remap)]