import torch.cuda.comm as comm
from torch.cuda._utils import _get_device_index


# torch.jit is imported lazily inside the helpers below to avoid a circular
# import when this module is loaded.
def _is_script_module(module):
    import torch.jit
    return isinstance(module, torch.jit.ScriptModule)


def _is_script_method(module):
    import torch.jit
    return isinstance(module, torch._C.ScriptMethod)


def _init_script_module():
    import torch.jit
    return torch.jit.ScriptModule()


def _is_jit_enabled():
    import torch.jit
    return torch.jit._enabled


# Check if we can safely replicate the module.
# There are three types of modules:
#   1. Python modules
#   2. weak Python modules (nn.Module annotated by @weak_module)
#   3. ScriptModules
#
# Currently, a module cannot be replicated properly if the descendants of
# any ScriptModule include a Python module (type 1 above).
def _replicatable_module(module, memo=None):

    # module.modules() yields the module itself as the first element
    def descendant_modules(module):
        gen = module.modules()
        next(gen)
        return gen

    if not _is_jit_enabled():
        return True
    if memo is None:
        memo = set()

    # memoize visited modules
    memo.add(module)
    if _is_script_module(module):
        memo.update(descendant_modules(module))
        return all(_is_script_module(descendant) for
                   descendant in descendant_modules(module))

    for child in module.children():
        # Since any unreplicatable module causes the check to return False
        # early, modules already visited here can be safely skipped.
        if child in memo:
            continue
        if not _replicatable_module(child, memo):
            return False

    return True
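
# Illustrative sketch of _replicatable_module (hypothetical, not executed at
# import): a tree of plain Python modules is replicatable, e.g.
#
#     import torch.nn as nn
#     _replicatable_module(nn.Sequential(nn.Linear(4, 4), nn.ReLU()))  # True
#
# whereas a ScriptModule with a plain Python nn.Module (type 1 above) among
# its descendants would make the check return False.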


def _copy_scriptmodule_methods(modules, module_copies, module_indices):
    for i, module in enumerate(modules):
        if not _is_script_module(module):
            continue
        replica = module_copies[i]
        for method_name in module._c._method_names():
            replica._c.clone_method(module._c, method_name)


def _broadcast_coalesced_reshape(tensors, devices, detach=False):
    from ._functions import Broadcast
    if detach:
        return comm.broadcast_coalesced(tensors, devices)
    else:
        # Use the autograd function Broadcast so that gradients flow back to
        # the source tensors; its flat output is reshaped below into one list
        # of copies per device, matching broadcast_coalesced's layout.
        if len(tensors) > 0:
            tensor_copies = Broadcast.apply(devices, *tensors)
            return [tensor_copies[i:i + len(tensors)]
                    for i in range(0, len(tensor_copies), len(tensors))]
        else:
            return []
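
# For illustration (hypothetical values): with tensors = [a, b] and
# devices = [0, 1], both branches of _broadcast_coalesced_reshape return
#
#     [[a_on_0, b_on_0], [a_on_1, b_on_1]]
#
# i.e. one inner sequence of copies per device.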


def replicate(network, devices, detach=False):
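    r"""Replicate ``network`` onto each device in ``devices``.

    Returns a list with one replica per device; replica ``i`` lives on
    ``devices[i]``. With ``detach=True``, the replicated parameters and
    buffers are detached from the autograd graph.
    """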
    if not _replicatable_module(network):
        raise RuntimeError("Cannot replicate network where Python modules are "
                           "children of ScriptModules")

    devices = [_get_device_index(x, True) for x in devices]
    num_replicas = len(devices)

    params = list(network.parameters())
    param_indices = {param: idx for idx, param in enumerate(params)}
    param_copies = _broadcast_coalesced_reshape(params, devices, detach)

    buffers = list(network.buffers())
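    # Buffers that require grad participate in the autograd-aware broadcast
    # unless detach is set; all other buffers are always broadcast detached.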
    buffers_rg = []
    buffers_not_rg = []
    for buf in buffers:
        if buf.requires_grad and not detach:
            buffers_rg.append(buf)
        else:
            buffers_not_rg.append(buf)

    buffer_indices_rg = {buf: idx for idx, buf in enumerate(buffers_rg)}
    buffer_indices_not_rg = {buf: idx for idx, buf in enumerate(buffers_not_rg)}

    buffer_copies_rg = _broadcast_coalesced_reshape(buffers_rg, devices, detach=detach)
    buffer_copies_not_rg = _broadcast_coalesced_reshape(buffers_not_rg, devices, detach=True)

    modules = list(network.modules())
    module_copies = [[] for _ in devices]
    module_indices = {}
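    # Attributes handled specially below: parameters, buffers, and submodules
    # are re-pointed per replica, while ScriptModule replicas get a fresh
    # "_c" and have their methods (including "forward") cloned afterwards by
    # _copy_scriptmodule_methods.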
    scriptmodule_skip_attr = {"_parameters", "_buffers", "_modules", "forward", "_c"}

    for i, module in enumerate(modules):
        module_indices[module] = i
        for j in range(num_replicas):
            if _is_script_module(module):
                # ScriptModule replicas must be created through torch.jit so
                # that the underlying C++ module (pybind11 binding) is set up
                # correctly.
                replica = _init_script_module()

                attribute_names = {entry[0] for entry in module._c._get_attributes()}

                keys = set(module.__dict__.keys()) - scriptmodule_skip_attr - attribute_names
                for key in keys:
                    if not _is_script_method(module.__dict__[key]):
                        replica.__dict__[key] = module.__dict__[key]
                for name, the_type, value in module._c._get_attributes():
                    if name in module._buffers.keys():
                        continue
                    replica._c._register_attribute(name, the_type, value)
            else:
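                # Construct the replica without running __init__, then give it
                # shallow copies of the state dicts; the wiring loop below
                # re-points their entries at the per-device copies.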
                replica = module.__new__(type(module))
                replica.__dict__ = module.__dict__.copy()
                replica._parameters = replica._parameters.copy()
                replica._buffers = replica._buffers.copy()
                replica._modules = replica._modules.copy()

            module_copies[j].append(replica)

    # Re-point every replica's submodules, parameters, and buffers at the
    # per-device copies created above.
    for i, module in enumerate(modules):
        for key, child in module._modules.items():
            if child is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._modules[key] = None
            else:
                module_idx = module_indices[child]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._modules[key] = module_copies[j][module_idx]
        for key, param in module._parameters.items():
            if param is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._parameters[key] = None
            else:
                param_idx = param_indices[param]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._parameters[key] = param_copies[j][param_idx]
        for key, buf in module._buffers.items():
            if buf is None:
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._buffers[key] = None
            else:
                if buf.requires_grad and not detach:
                    buffer_copies = buffer_copies_rg
                    buffer_idx = buffer_indices_rg[buf]
                else:
                    buffer_copies = buffer_copies_not_rg
                    buffer_idx = buffer_indices_not_rg[buf]
                for j in range(num_replicas):
                    replica = module_copies[j][i]
                    replica._buffers[key] = buffer_copies[j][buffer_idx]

    for j in range(num_replicas):
        _copy_scriptmodule_methods(modules, module_copies[j], module_indices)

    return [module_copies[j][0] for j in range(num_replicas)]
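

if __name__ == "__main__":
    # Minimal usage sketch (not part of the library API; assumes at least two
    # visible CUDA devices). This mirrors how torch.nn.DataParallel uses
    # replicate() internally.
    import torch
    import torch.nn as nn

    if torch.cuda.device_count() >= 2:
        net = nn.Linear(10, 5).cuda(0)
        replicas = replicate(net, [0, 1])
        for idx, rep in enumerate(replicas):
            # Each replica's parameters live on its own device.
            print(idx, next(rep.parameters()).device)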