import torch
import torch.distributed as dist
from torch.distributed import distributed_c10d
from torch.distributed._shard.sharded_tensor import (
    ShardedTensor,
)
from .sharding_spec import (
    ShardingSpec,
)
from .replicated_tensor import ReplicatedTensor


def _shard_tensor(
    tensor: torch.Tensor, sharding_spec: ShardingSpec, src_rank=0, process_group=None
) -> ShardedTensor:
"""
Given a :class:`torch.Tensor`, it shards that tensor according to the provided
``sharding_spec``. ``src_rank`` denotes the source rank which would be
used as the ground truth of the data which would be scattered as shards
across the rest of the ranks.
Args:
tensor (:class:`torch.Tensor`): Tensor needs to be sharded.
sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
describing how to shard the Tensor.
Keyword args:
src_rank (int, optional): The source rank which is used as the ground truth of
the data for the parameter that would be sharded and scattered
across the rest of the ranks.
Default: 0.
process_group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used.
Returns:
A :class:`ShardedTensor` sharded from the given tensor.
.. warning::
Only :class:`torch.distributed._shard.sharding_spec.ChunkShardingSpec` is
currently supported as the ``sharding_spec``.
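
    The example below is an illustrative sketch; it assumes four ranks, one GPU
    per rank, and an already-initialized default process group. The tensor shape
    is arbitrary.

    Example::
        >>> from torch.distributed._shard.sharding_spec import ChunkShardingSpec
        >>> # Chunk the tensor along dim 0, placing one chunk on each rank's GPU.
        >>> spec = ChunkShardingSpec(
        >>>     dim=0,
        >>>     placements=[
        >>>         "rank:0/cuda:0",
        >>>         "rank:1/cuda:1",
        >>>         "rank:2/cuda:2",
        >>>         "rank:3/cuda:3",
        >>>     ],
        >>> )
        >>> tensor = torch.rand(16, 32, device=f"cuda:{dist.get_rank()}")
        >>> st = _shard_tensor(tensor, spec)  # each rank now holds a 4 x 32 shard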
"""
    if not tensor.is_contiguous():
        raise ValueError('input tensor is not a contiguous Tensor')

    pg = process_group if process_group is not None else distributed_c10d._get_default_group()
    world_size = dist.get_world_size(pg)
    current_rank = dist.get_rank(pg)

    # Validate that src_rank and sharding_spec are the same across all ranks.
    gathered_list = [None] * world_size
    dist.all_gather_object(gathered_list, (src_rank, sharding_spec), group=pg)

    for idx, entry in enumerate(gathered_list):
        if src_rank != entry[0]:  # type: ignore[index]
            raise ValueError(
                f'src_rank={src_rank} on rank: {current_rank} does not '  # type: ignore[index]
                f'match with src_rank={entry[0]} on rank: {idx}')
        if sharding_spec != entry[1]:  # type: ignore[index]
            raise ValueError(
                f'sharding_spec={sharding_spec} on rank: {current_rank} does not '  # type: ignore[index]
                f'match with sharding_spec={entry[1]} on rank: {idx}')

    st = sharding_spec.shard(tensor, src_rank=src_rank, process_group=process_group)
    return st


def shard_parameter(
        module: torch.nn.Module,
        param_name: str,
        sharding_spec: ShardingSpec,
        src_rank=0,
        process_group=None):
"""
Given a :class:`torch.nn.Module`, a ``param_name`` for a parameter in that
module, it shards that parameter according to the provided
``sharding_spec``. ``src_rank`` denotes the source rank which would be
used as the ground truth of the data which would be scattered as shards
across the rest of the ranks.
This method replaces ``module.param_name`` with a
:class:`torch.distributed._sharded_tensor.ShardedTensor`
Args:
module (:class:`torch.nn.Module`): Module whose parameter needs to be sharded.
param_name (str): Name of the parameter of ``module`` that needs to be sharded.
sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`): The specification
describing how to shard the Tensor.
Keyword args:
src_rank (int, optional): The source rank which is used as the ground truth of
the data for the parameter that would be sharded and scattered
across the rest of the ranks.
Default: 0.
process_group (ProcessGroup, optional): The process group to work on. If None,
the default process group will be used.
.. warning::
Only :class:`torch.distributed._shard.sharding_spec.ChunkShardingSpec` is
currently supported as the ``sharding_spec``.
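
    The example below is a sketch that assumes four ranks with one GPU each and
    an already-initialized default process group; the module and parameter names
    are arbitrary.

    Example::
        >>> from torch.distributed._shard.sharding_spec import ChunkShardingSpec
        >>> spec = ChunkShardingSpec(
        >>>     dim=0,
        >>>     placements=[
        >>>         "rank:0/cuda:0",
        >>>         "rank:1/cuda:1",
        >>>         "rank:2/cuda:2",
        >>>         "rank:3/cuda:3",
        >>>     ],
        >>> )
        >>> linear = torch.nn.Linear(32, 64).cuda()
        >>> # Shard ``linear.weight`` along dim 0; the attribute is replaced in place.
        >>> shard_parameter(linear, "weight", spec, src_rank=0)
        >>> isinstance(linear.weight, ShardedTensor)
        True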
"""
    # Perform some validation first.
    if not hasattr(module, param_name):
        raise ValueError(f'module: {module} does not have parameter with name: {param_name}')

    tensor = getattr(module, param_name)
    if not isinstance(tensor, torch.Tensor):
        raise ValueError(f'Expected {type(module).__name__}.{param_name} to be a Tensor, but found {type(tensor).__name__}')

    if not tensor.is_contiguous():
        raise ValueError(f'param: {param_name} is not a contiguous Tensor')

    st = _shard_tensor(tensor, sharding_spec, src_rank, process_group)

    # Replace the param with a ShardedTensor. The attribute has to be deleted
    # first, since ``param_name`` might be a ``torch.nn.Parameter`` and cannot be
    # directly assigned a ShardedTensor, which is not an ``nn.Parameter``.
    delattr(module, param_name)

    # Now we can set the attribute appropriately.
    setattr(module, param_name, st)


def _replicate_tensor(tensor: torch.Tensor, process_group=None) -> ReplicatedTensor:
"""
Given a :class:`torch.Tensor`, mark it as a ReplicatedTensor where all
ranks have the same value.
Args:
tensor (:class:`torch.Tensor`): the tensor to be marked as replicated.
Keyword args:
process_group (ProcessGroup, optional): The process group to replicate on.
If None, the default process group will be used.
Returns:
A :class:`ReplicatedTensor` from the given tensor.
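
    The example below is a sketch that assumes the default process group has been
    initialized and that ``tensor`` already holds identical values on every rank.

    Example::
        >>> tensor = torch.ones(4, 4)
        >>> replica = _replicate_tensor(tensor)
        >>> isinstance(replica, ReplicatedTensor)
        True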
"""
return ReplicatedTensor(tensor, process_group=process_group)