import torch
import torch.distributed as dist
from torch.distributed import distributed_c10d
from torch.distributed._shard.sharded_tensor import (
    ShardedTensor,
)
from .sharding_spec import (
    ShardingSpec,
)
from .replicated_tensor import ReplicatedTensor


def _shard_tensor(
    tensor: torch.Tensor, sharding_spec: ShardingSpec, src_rank=0, process_group=None
) -> ShardedTensor:
    """
    Given a :class:`torch.Tensor`, shards that tensor according to the provided
    ``sharding_spec``. ``src_rank`` denotes the source rank whose data is used
    as the ground truth and scattered as shards across the rest of the ranks.

    Args:
        tensor (:class:`torch.Tensor`): The tensor that needs to be sharded.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The
            specification describing how to shard the tensor.

    Keyword args:
        src_rank (int, optional): The source rank which is used as the ground truth of
            the data for the parameter that would be sharded and scattered
            across the rest of the ranks.
            Default: 0.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.

    Returns:
        A :class:`ShardedTensor` sharded from the given tensor.

    .. warning::
        Only :class:`torch.distributed._shard.sharding_spec.ChunkShardingSpec` is
        currently supported as the ``sharding_spec``.
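
    Example::
        >>> # A minimal usage sketch, not a runnable test: it assumes the
        >>> # default process group is already initialized across 2 ranks and
        >>> # that the placements below match devices that actually exist.
        >>> from torch.distributed._shard.sharding_spec import ChunkShardingSpec
        >>> spec = ChunkShardingSpec(
        ...     dim=0,
        ...     placements=["rank:0/cuda:0", "rank:1/cuda:1"],
        ... )
        >>> # Every rank calls this; rank 0 supplies the source data.
        >>> st = _shard_tensor(torch.rand(8, 4), spec, src_rank=0)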
| """ |
    if not tensor.is_contiguous():
        raise ValueError('input tensor is not a contiguous Tensor')

    pg = process_group if process_group is not None else distributed_c10d._get_default_group()
    world_size = dist.get_world_size(pg)
    current_rank = dist.get_rank(pg)

    # Validate that src_rank and sharding_spec are the same across all ranks.
    gathered_list = [None] * world_size
    dist.all_gather_object(gathered_list, (src_rank, sharding_spec), group=pg)

    for idx, entry in enumerate(gathered_list):
        if src_rank != entry[0]:  # type: ignore[index]
            raise ValueError(
                f'src_rank={src_rank} on rank: {current_rank} does not '  # type: ignore[index]
                f'match with src_rank={entry[0]} on rank: {idx}')
        if sharding_spec != entry[1]:  # type: ignore[index]
            raise ValueError(
                f'sharding_spec={sharding_spec} on rank: {current_rank} does not '  # type: ignore[index]
                f'match with sharding_spec={entry[1]} on rank: {idx}')

    st = sharding_spec.shard(tensor, src_rank=src_rank, process_group=process_group)

    return st


def shard_parameter(
        module: torch.nn.Module,
        param_name: str,
        sharding_spec: ShardingSpec,
        src_rank=0,
        process_group=None):
    """
    Given a :class:`torch.nn.Module` and the ``param_name`` of a parameter in that
    module, shards that parameter according to the provided ``sharding_spec``.
    ``src_rank`` denotes the source rank whose data is used as the ground truth
    and scattered as shards across the rest of the ranks.

    This method replaces ``module.param_name`` with a
    :class:`torch.distributed._shard.sharded_tensor.ShardedTensor`.

    Args:
        module (:class:`torch.nn.Module`): Module whose parameter needs to be sharded.
        param_name (str): Name of the parameter of ``module`` that needs to be sharded.
        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The
            specification describing how to shard the tensor.

    Keyword args:
        src_rank (int, optional): The source rank which is used as the ground truth of
            the data for the parameter that would be sharded and scattered
            across the rest of the ranks.
            Default: 0.
        process_group (ProcessGroup, optional): The process group to work on. If None,
            the default process group will be used.

    .. warning::
        Only :class:`torch.distributed._shard.sharding_spec.ChunkShardingSpec` is
        currently supported as the ``sharding_spec``.
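
    Example::
        >>> # A minimal usage sketch, not a runnable test: it assumes the
        >>> # default process group is already initialized across 2 ranks and
        >>> # that the placements below match devices that actually exist.
        >>> from torch.distributed._shard.sharding_spec import ChunkShardingSpec
        >>> spec = ChunkShardingSpec(
        ...     dim=0,
        ...     placements=["rank:0/cuda:0", "rank:1/cuda:1"],
        ... )
        >>> linear = torch.nn.Linear(16, 16)
        >>> # Every rank calls this; afterwards ``linear.weight`` is a
        >>> # ShardedTensor rather than a torch.nn.Parameter.
        >>> shard_parameter(linear, "weight", spec)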
| """ |
    # Perform some validation first.
    if not hasattr(module, param_name):
        raise ValueError(f'module: {module} does not have parameter with name: {param_name}')

    tensor = getattr(module, param_name)
    if not isinstance(tensor, torch.Tensor):
        raise ValueError(f'Expected {type(module).__name__}.{param_name} to be a Tensor, but found {type(tensor).__name__}')

    if not tensor.is_contiguous():
        raise ValueError(f'param: {param_name} is not a contiguous Tensor')

    st = _shard_tensor(tensor, sharding_spec, src_rank, process_group)

    # Replace the param with a ShardedTensor.

    # Delete the attribute first since the existing attribute might be a
    # torch.nn.Parameter and can't be replaced with a ShardedTensor, which
    # is not a torch.nn.Parameter.
    delattr(module, param_name)

    # Now we can set the attribute appropriately.
    setattr(module, param_name, st)


def _replicate_tensor(tensor: torch.Tensor, process_group=None) -> ReplicatedTensor:
    """
    Given a :class:`torch.Tensor`, mark it as a :class:`ReplicatedTensor` where all
    ranks have the same value.

    Args:
        tensor (:class:`torch.Tensor`): The tensor to be marked as replicated.

    Keyword args:
        process_group (ProcessGroup, optional): The process group to replicate on.
            If None, the default process group will be used.

    Returns:
        A :class:`ReplicatedTensor` from the given tensor.
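
    Example::
        >>> # A minimal usage sketch: it assumes the default process group is
        >>> # already initialized and every rank holds the same values.
        >>> replica = _replicate_tensor(torch.ones(4, 4))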
    """
    return ReplicatedTensor(tensor, process_group=process_group)