from typing import List, Optional

from torch import nn
| |
def partition_model(
        module: nn.Sequential,
        balance: List[int],
        devices: Optional[List[int]] = None) -> nn.Sequential:
    """
    Given an :class:`nn.Sequential <torch.nn.Sequential>` module, partitions
    the model across multiple GPU devices according to the provided
    ``balance`` and ``devices``.

    Args:
        module (:class:`nn.Sequential <torch.nn.Sequential>`):
            Sequential model representing the pipe.
        balance (List[int]):
            List indicating the number of layers in each partition. Entries
            must be non-negative and must sum to ``len(module)`` so that every
            layer is assigned to exactly one partition.
        devices (List[int], optional):
            List indicating the device to use for each partition; must have at
            least ``len(balance)`` entries. Each entry is passed straight to
            :meth:`nn.Module.to`. Defaults to ``range(len(balance))``.

    Returns:
        nn.Sequential: A sequential of ``len(balance)`` child
        ``nn.Sequential`` partitions, in order, each moved to its device.
        Note that ``.to()`` moves the layers in place, so the layers of the
        input ``module`` are the same objects and end up on those devices too.

    Raises:
        ValueError: If ``balance`` contains a negative entry, if it does not
            sum to ``len(module)``, or if ``devices`` has fewer entries than
            ``balance``.
    """
    if any(num_layers < 0 for num_layers in balance):
        raise ValueError(
            f"balance entries must be non-negative, got {list(balance)}")
    if sum(balance) != len(module):
        # Previously a short balance silently dropped the trailing layers and
        # an over-long one failed with an opaque IndexError.
        raise ValueError(
            f"sum of balance ({sum(balance)}) must equal the number of "
            f"layers in module ({len(module)})")
    if devices is None:
        devices = list(range(len(balance)))
    elif len(devices) < len(balance):
        raise ValueError(
            f"devices has {len(devices)} entries but balance defines "
            f"{len(balance)} partitions")

    layers = list(module)
    partitions = []
    start = 0
    for num_layers, device in zip(balance, devices):
        end = start + num_layers
        # Rebuild from a plain list rather than slicing the Sequential, so each
        # partition's children are keyed "0", "1", ... as before.
        partitions.append(nn.Sequential(*layers[start:end]).to(device))
        start = end

    return nn.Sequential(*partitions)