# blob: a877bf62f2e695741e1c6a95775f732f5d554004 [file] [log] [blame]
from copy import copy
from collections import OrderedDict
from ..modules.container import Container
def _replicate_module(module, gpu, param_remap):
if module is None:
return module
replica = copy(module)
replica._parameters = OrderedDict()
for key, param in module._parameters.items():
replica._parameters[key] = param_remap.get(param)
if isinstance(replica, Container):
replica.modules = OrderedDict()
for name, child in module.modules.items():
replica.modules[name] = _replicate_module(child, gpu, param_remap)
return replica
def replicate(module, device_ids):
    """Build one replica of ``module`` for each entry in ``device_ids``.

    Every distinct parameter yielded by ``module.parameters()`` is passed
    once through ``Broadcast(device_ids)``. The resulting copies are paired
    positionally with one remap dict per device, and each remap dict is then
    handed to ``_replicate_module`` to produce that device's replica.

    Args:
        module: Object exposing ``parameters()``; forwarded to
            ``_replicate_module``.
        device_ids: Sequence of device identifiers. It is iterated more than
            once, so it should be re-iterable (e.g. a list or tuple).

    Returns:
        A list of replicas, in the same order as ``device_ids``.
    """
    from .functions import Broadcast

    visited = set()
    per_device_remap = [dict() for _ in device_ids]

    for original in module.parameters():
        # A parameter may be yielded more than once; broadcast it only the
        # first time it is encountered.
        if original not in visited:
            visited.add(original)
            # NOTE(review): Broadcast presumably yields one copy per device,
            # in ``device_ids`` order -- confirm against ``.functions``.
            broadcasted = Broadcast(device_ids)(original)
            for replica_param, remap in zip(broadcasted, per_device_remap):
                remap[original] = replica_param

    replicas = []
    for device_id, remap in zip(device_ids, per_device_remap):
        replicas.append(_replicate_module(module, device_id, remap))
    return replicas