Allow for custom sharding specs to register their own ops.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/76360
Customized ShardingSpecs can be entirely arbitrary, so their ops might not fit the
patterns supported by the built-in ShardingSpecs and cannot be handled generically.
To address this, we introduce a framework that lets a ShardingSpec override ops
(a usage sketch follows the list) as follows:
1) In the dispatch system, if a ShardingSpec has a customized implementation
registered for an op, that implementation is invoked for tensors sharded with that ShardingSpec.
2) As a result, all ChunkShardingSpec-specific ops have been moved under that
ShardingSpec.
3) A set of ShardingSpec-agnostic ops (ex: elementwise ops) remains as the common
ops supported across any ShardingSpec.
4) If an op is not found for a particular ShardingSpec, the default set of ops
is searched for it.
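
For illustration, a custom spec could hook into this framework roughly as follows. This is a minimal, hypothetical sketch: `MyShardingSpec` and `my_sharded_linear` are invented names, while `custom_sharding_spec_op` and the required `(types, args, kwargs)` signature are the hooks this PR adds in `torch/distributed/_shard/sharding_spec/api.py`.

```python
# Hypothetical example: register an op implementation specific to a custom spec.
import torch
from torch.distributed._shard.sharding_spec.api import (
    ShardingSpec,
    custom_sharding_spec_op,
)

class MyShardingSpec(ShardingSpec):
    # build_metadata()/shard() omitted for brevity; any arbitrary layout works.
    ...

# Used only when the ShardedTensor involved is sharded with MyShardingSpec.
@custom_sharding_spec_op(MyShardingSpec, torch.nn.functional.linear)
def my_sharded_linear(types, args, kwargs):
    # args/kwargs are the original torch.nn.functional.linear arguments;
    # compute and return the result for this sharding layout here.
    ...
```

At `__torch_function__` dispatch time, `ShardedTensor` first checks whether the sharding spec of the tensor involved has a custom implementation registered for the op and invokes it; only if none is found does it fall back to the default `_SHARDED_OPS` table.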
Differential Revision: [D35917912](https://our.internmc.facebook.com/intern/diff/D35917912/)
Approved by: https://github.com/wanchaol
diff --git a/test/distributed/_shard/sharded_tensor/ops/test_math_ops.py b/test/distributed/_shard/sharded_tensor/ops/test_math_ops.py
index b6c9930..f3129f0 100644
--- a/test/distributed/_shard/sharded_tensor/ops/test_math_ops.py
+++ b/test/distributed/_shard/sharded_tensor/ops/test_math_ops.py
@@ -129,7 +129,7 @@
st = sharded_tensor.rand(spec, 10, 10)
- with self.assertRaisesRegex(TypeError, 'with ChunkShardingSpec supports'):
+ with self.assertRaisesRegex(RuntimeError, 'not supported'):
torch.add(st, sharded_rhs)
diff --git a/torch/distributed/_shard/sharded_tensor/_ops/__init__.py b/torch/distributed/_shard/sharded_tensor/_ops/__init__.py
index 2545ccd..f59c8d3 100644
--- a/torch/distributed/_shard/sharded_tensor/_ops/__init__.py
+++ b/torch/distributed/_shard/sharded_tensor/_ops/__init__.py
@@ -1,10 +1,13 @@
import torch.distributed._shard.sharded_tensor._ops.chunk
import torch.distributed._shard.sharded_tensor._ops.elementwise_ops
import torch.distributed._shard.sharded_tensor._ops.math_ops
-import torch.distributed._shard.sharded_tensor._ops.matrix_ops
from .binary_cmp import equal, allclose
-from .embedding import sharded_embedding
-from .embedding_bag import sharded_embedding_bag
from .init import kaiming_uniform_, normal_, uniform_, constant_
-from .linear import sharded_linear
+
+# Import all ChunkShardingSpec ops
+from torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.linear import sharded_linear
+from torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.embedding import sharded_embedding
+from torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.embedding_bag import sharded_embedding_bag
+import torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.math_ops
+import torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.matrix_ops
diff --git a/torch/distributed/_shard/sharded_tensor/_ops/_common.py b/torch/distributed/_shard/sharded_tensor/_ops/_common.py
index 3ef25a2..b5dc61f 100644
--- a/torch/distributed/_shard/sharded_tensor/_ops/_common.py
+++ b/torch/distributed/_shard/sharded_tensor/_ops/_common.py
@@ -1,35 +1,9 @@
-# coding=utf-8
-
import functools
-from typing import List
-
-import torch
-import torch.distributed as dist
-import torch.distributed._shard.sharding_spec as shard_spec
from torch.distributed._shard.sharded_tensor import (
sharded_op_impl,
Shard,
ShardedTensor,
)
-from torch.distributed._shard.sharding_spec._internals import (
- get_split_size,
- get_chunked_dim_size,
-)
-from torch.distributed.nn.functional import (
- all_gather,
- all_to_all_single,
-)
-
-
-def _chunk_sharding_spec_check(spec, op):
- """
- For the given op implementation check if the sharding spec is ChunkShardingSpec.
- """
- if not isinstance(spec, shard_spec.ChunkShardingSpec):
- raise NotImplementedError(
- f"Only ChunkShardingSpec supported for '{op.__name__}'."
- )
-
def _sharded_op_common(op, early_stop_func, extra_check):
"""
@@ -84,7 +58,6 @@
return decorator_sharded_func
-
def _register_sharded_op_on_local_shards(
op, early_stop_func=None, extra_check=None, customized_func=None
):
@@ -132,421 +105,3 @@
process_group=pg,
init_rrefs=st._init_rrefs,
)
-
-
-def _register_sharded_op_on_local_tensor(
- op, early_stop_func=None, extra_check=None, customized_func=None
-):
- """
- Handles ``__torch_function__`` dispatch for ops which are performed on
- the single local tensor of the sharded tensor such as op like
- ``torch.nn.functional.softmax`` or ``torch.Tensor.view``.
-
- For more complicated ops, a customized func can be used to generate
- the new local tensor, sharding spec and sharded tensor size.
-
- Args:
- op: The op to be registered and applied to all shards of the st.
- early_stop_func (Callable, optional): the func for early stop.
- Default: if ``None``, no early stop.
- extra_check (Callable, optional): the func for extra condition check.
- Default: if ``None``, no extra check.
- customized_func (Callable, optional): the func for customized logic
- to generate the new local tensor, sharding spec and sharded tensor size.
- Default: if ``None``, we simply lower to the real op call with
- the single local tensor of the st.
-
- Return:
- func (Callable): registered implementation for sharded op for
- ``__torch_function__`` dispatch.
- """
- @sharded_op_impl(op)
- @_sharded_op_common(op, early_stop_func, extra_check)
- def sharded_tensor_op_on_local_shards(types, args=(), kwargs=None, pg=None):
- st = args[0]
- sharding_spec = st.sharding_spec()
- _chunk_sharding_spec_check(sharding_spec, op)
- if len(st.local_shards()) != 1:
- raise TypeError(
- f"torch function '{op.__name__}', with args: {args} and "
- f"kwargs: {kwargs} only supported for single local tensor!"
- )
- st_size = st.size()
- if customized_func:
- local_tensor, sharding_spec, st_size = customized_func(args, kwargs, pg)
- else:
- args = (st.local_tensor(), *args[1:])
- local_tensor = op(*args, **kwargs)
- return ShardedTensor._init_from_local_tensor(
- local_tensor.contiguous(),
- sharding_spec,
- st_size, # type: ignore[arg-type]
- process_group=pg,
- init_rrefs=st._init_rrefs,
- )
-
-
-def _handle_col_wise_sharding_base(
- op_func,
- col_dim,
- input,
- world_size,
- weight,
- local_shard,
- pg,
- gathered_inputs=None,
- mode=None,
- gathered_per_sample_weights=None,
- gathered_offsets=None,
- padding_idx=None,
-):
- """
- For col-wise sharding of weight, lots of logic are common.
- So we extract the common logic and put in this function:
- Step 1. To get input from each rank and
- Step 2. To perform the op on the concatenated tensor.
- Step 3. To distribute results to each rank with col rearrangement.
- Step 4. To concatenate all results from all ranks.
-
- Args:
- op_func: operator which is applied to the input tensor.
- col_dim: dim of result tensor after the operation.
- input: tensor to be applied op on.
- world_size: number of ranks.
- weight: shareded weight tensor.
- local_shard: col-wise sharded weight tensor.
- pg: process group.
- gathered_inputs: list of inputs from all ranks. If specified, we
- don't need to communicate with each rank any more.
- mode: aggregation mode of EmbeddingBag.
- gathered_per_sample_weights: per_sample_weights across all ranks.
- gathered_offsets: offsets across all ranks.
- padding_idx: If specified, the entries at padding_idx do
- not contribute to the gradient; therefore, the embedding
- vector at padding_idx is not updated during training,
- i.e. it remains as a fixed “pad”.
- Note that the embedding vector at padding_idx is
- excluded from the reduction.
-
- Return: final result of input being applied with the op.
- """
- if gathered_inputs is None:
- # allgather the inputs first.
- gathered_inputs = all_gather(input, group=pg)
-
- # run the operator's function for all the inputs.
- results = []
- for i, inp in enumerate(gathered_inputs):
- if op_func == torch.nn.functional.embedding_bag:
- result = op_func(
- inp,
- local_shard,
- offsets=gathered_offsets[i] if gathered_offsets is not None else None,
- mode=mode,
- per_sample_weights=gathered_per_sample_weights[i]
- if gathered_per_sample_weights is not None
- else None,
- padding_idx=padding_idx,
- )
- elif op_func == torch.nn.functional.embedding:
- result = op_func(
- inp,
- local_shard,
- padding_idx=padding_idx,
- )
- else:
- result = op_func(inp, local_shard)
- results.append(torch.transpose(result, 0, col_dim))
-
- # Distribute results to each rank with col rearrangement.
- output = _result_distribute_with_col_rearrange(
- results, input, world_size, weight, pg
- )
-
- # transpose the output and return result.
- return torch.transpose(output, 0, col_dim)
-
-
-def _result_distribute_with_col_rearrange(
- results, input, world_size, weight, pg
-):
- """
- For col-wise sharding of weight, we need to distribute
- results to each rank. We do them in this function.
- Note that, if the index in the Sharding Spec is not equal to
- the rank number, we need to do the rearrangement based on the
- order given by the Sharding Spec (placement).
-
- Args:
- results: results from ops applied to inputs from all ranks.
- We need to distribute them back to their original ranks.
- input: tensor to be applied op to.
- world_size: number of ranks.
- weight: shareded weight tensor.
- pg: process group.
-
- Return: column rearranged result.
- """
- # Process results and outputs for all2all.
- sharding_dim = weight._sharding_spec.dim
- sharding_dim_size = weight.size(sharding_dim)
- dims = list(results[0].size())
- dims[0] = sharding_dim_size
- output = torch.empty(*dims, device=input.device)
- combined_results = torch.cat(results)
-
- # Compute output splits
- split_size = get_split_size(sharding_dim_size, world_size)
- output_split_sizes = [0] * world_size
- for idx, placement in enumerate(weight._sharding_spec.placements):
- output_split_sizes[placement.rank()] = get_chunked_dim_size(
- sharding_dim_size, split_size, idx
- )
-
- # distribute the outputs using all2all.
- output = all_to_all_single(
- output, combined_results, output_split_sizes=output_split_sizes, group=pg
- )
-
- # Check if we need to rearrange columns appropriately for output.
- rearrange_columns = any(
- [
- idx != placement.rank()
- for idx, placement in enumerate(weight._sharding_spec.placements)
- ]
- )
- if not rearrange_columns:
- return output
-
- indices = []
- for placement in weight._sharding_spec.placements:
- dim_size = output_split_sizes[placement.rank()]
- start = sum(
- [
- split_size if i < placement.rank() else 0
- for i, split_size in enumerate(output_split_sizes)
- ]
- )
- indices += list(range(start, start + dim_size))
-
- return output.index_select(0, torch.tensor(indices, device=output.device))
-
-
-def _handle_row_wise_lookup_distribute(
- input_sorted, input, world_size, weight, rank, padding_idx
-):
- """
- In the circumstance of row-wise sharding of weight, we need to distribute
- the sorted lookup IDs of embedding/embeddingBag to each rank.
- If the index in the placement is not equal to the rank number, we need to
- do the rearrangement based on the order given by the Sharding Spec (placement).
-
- In addition, we do two things for padding_idx. The first thing is to only
- set it if it's within the range of the current rank and the other thing
- is to do the modularization of it by sharded_dim_size_max.
-
- Args:
- input_sorted: sorted lookup IDs of embedding/embeddingBag.
- input: tensor to be applied op on.
- world_size: number of ranks.
- weight: shareded weight tensor.
- rank: # of cuda process.
- padding_idx: If specified, the entries at padding_idx do
- not contribute to the gradient and reduction.
-
- Return:
- input_sorted: sorted lookup IDs of embedding/embeddingBag
- Rearrangement performed if it is needed.
- input_split_sizes: size of IDs to be assigned to each rank.
- sharded_dim_size_max: the max size of the row each rank gets.
- input_split_rearrange_indices: indices of row rearrangement.
- rearrange_indices_1d_second_order: reverse indices of row
- rearrangement, which will be used to restore the original
- order.
- padding_idx: Same as input if padding_idx is within the range
- of the given rank; otherwise, None is returned. It is
- also modularized by sharded_dim_size_max.
- """
- # Decide which rank the input goes to by check the sharding range.
- split_size = get_split_size(weight.size(0), world_size)
- rearrange_rows = False
- indices_flatten = None
- input_split_sizes: List[int] = [0] * world_size
- input_split_start_indices: List[int] = [0] * world_size
- start_row_idx_rank = None
- end_row_idx_rank = None
- # When we do the chunk split, we always ensure the first N - 1 chunks get max out
- # and then the Nth chunk gets the rest. So input_split_sizes like [3, 3, 3, 4]
- # are not possible. The expected split size will be [4, 4, 4, 1].
- sharded_dim_size_max = get_chunked_dim_size(weight.size(0), split_size, 0)
- for idx, placement in enumerate(weight._sharding_spec.placements):
- sharded_dim_size = get_chunked_dim_size(weight.size(0), split_size, idx)
- start_row_idx = idx * sharded_dim_size_max
- end_row_idx = start_row_idx + sharded_dim_size
- start_idx = torch.searchsorted(input_sorted, start_row_idx).item()
- end_idx = torch.searchsorted(input_sorted, end_row_idx).item()
- input_split_sizes[placement.rank()] = int(end_idx - start_idx)
- input_split_start_indices[placement.rank()] = int(start_idx)
- if placement.rank() != idx:
- rearrange_rows = True
- # Store the range of the current rank.
- if placement.rank() == rank:
- start_row_idx_rank = start_row_idx
- end_row_idx_rank = end_row_idx
-
- # Perform the modular if padding_idx is within the range.
- if padding_idx is not None:
- if padding_idx < start_row_idx_rank or padding_idx >= end_row_idx_rank:
- padding_idx = None
- else:
- padding_idx = padding_idx % sharded_dim_size_max
-
- rearrange_indices_1d_second_order = None
- if rearrange_rows:
- # Need to re-arrange the 1D tensor to be sent via all2all.
- indices: List[List[int]] = [[0]] * world_size
- for placement in weight._sharding_spec.placements:
- split_length = input_split_sizes[placement.rank()]
- offset_idx = input_split_start_indices[placement.rank()]
- indices[placement.rank()] = list(
- range(offset_idx, offset_idx + split_length)
- )
- indices_flatten = list(idx for indice in indices for idx in indice)
-
- input_sorted = input_sorted.index_select(
- 0, torch.tensor(indices_flatten, device=input.device)
- )
- rearrange_indices_1d_second_order = torch.argsort(torch.Tensor(indices_flatten))
-
- return (
- input_sorted,
- input_split_sizes,
- sharded_dim_size_max,
- torch.tensor(indices_flatten, device=input.device) if rearrange_rows else None,
- rearrange_indices_1d_second_order,
- padding_idx,
- )
-
-
-def _communicate_size_to_each_rank(
- input_size_list, output_size, input, pg, tensor_type=torch.int
-):
- """
- In the circumstance of row-wise sharding of weight, we need to first
- communicate the input length to each rank because each rank gets a
- different one.
-
- Args:
- input_size_list: list of sizes to be sent to each rank.
- output_size: length of the output tensor.
- input: tensor to be applied op on.
- pg: process group.
- tensor_type: dtype of tensor.
-
- Return: A list of communication results (int).
- """
- input_size_list_tensor = torch.tensor(
- input_size_list, dtype=tensor_type, device=input.device
- )
- output_size_list_tensor = torch.empty(
- output_size, dtype=tensor_type, device=input.device
- )
- dist.all_to_all_single(
- output_size_list_tensor,
- input_size_list_tensor,
- group=pg,
- )
- return output_size_list_tensor.tolist()
-
-
-def _communicate_list_to_each_rank(
- input_tensor_list, output_lists, input, pg, tensor_type=torch.int64
-):
- """
- In the circumstance of row-wise sharding of weight, we need to
- communicate a list of input tensors to each rank. Because the
- input could be a list of list, we need to first convert the list
- to a tensor.
-
- Args:
- input_tensor_list: list of tensors to be sent to each rank.
- output_lists: list of sizes to be obtained from each rank.
- input: tensor to be applied op on.
- pg: process group.
- tensor_type: dtype of tensor.
-
- Return: A list of communication results (tensors).
- """
- output_tensor_list = []
- for output_list in output_lists:
- output_tensor_list.append(
- torch.empty(output_list, dtype=tensor_type, device=input.device)
- )
- dist.all_to_all(
- output_tensor_list,
- input_tensor_list,
- group=pg,
- )
- return output_tensor_list
-
-
-def _handle_max_norm_col_wise(
- max_norm,
- norm_type,
- local_shard,
- input,
- world_size,
- pg,
-):
- """
- For col-wise sharding of weight, we need to aggregate the
- norm across all ranks before we can perform the proper re-norm.
- Note that, the max_norm logic is only applied to the embedding
- indices that are looked up and not the whole shard.
-
- Args:
- max_norm: If given, each embedding vector with norm larger
- than max_norm is renormalized to have norm max_norm.
- Note: this will modify weight in-place.
- norm_type: The p in the p-norm to compute for the max_norm option.
- local_shard: col-wise shared local weight used for lookup.
- input: tensor to be applied op to.
- world_size: number of ranks.
- pg: process group.
-
- Return:
- local_shard_norm_renormed: local_shard re-normed to max_norm if the norm is larger
- than it.
- gathered_inputs: list of inputs from all ranks.
- """
- norm_type = norm_type if norm_type is not None else 2.0
- # allgather the inputs first.
- gathered_inputs = [torch.zeros_like(input) for _ in range(world_size)]
- dist.all_gather(gathered_inputs, input, group=pg)
- unique_inp = torch.unique(torch.cat(gathered_inputs))
- local_shard_sum = torch.sum(
- torch.pow(torch.abs(local_shard), norm_type), dim=1, dtype=local_shard.dtype
- )
- # For col-wise sharding, we need to first aggregate the powered sum
- # from each rank first and then calculate the norm.
- dist.all_reduce(local_shard_sum, group=pg)
- local_shard_norm = torch.pow(local_shard_sum, 1.0 / norm_type)
- max_norm_tensor = torch.full(
- (local_shard.size(0),),
- float("inf"),
- dtype=local_shard.dtype,
- device=input.device,
- )
- max_norm_tensor[unique_inp] = max_norm
- local_shard_t = local_shard.t().contiguous()
- normalized_tensor = torch.where(
- local_shard_norm > max_norm_tensor, max_norm_tensor, local_shard_norm
- )
- # Make sure divisor is not zero.
- local_shard_norm[local_shard_norm == 0.0] = 1.0
- local_shard_norm_renormed = (
- torch.div(torch.mul(local_shard_t, normalized_tensor), local_shard_norm)
- .t()
- .contiguous()
- )
- return local_shard_norm_renormed, gathered_inputs
diff --git a/torch/distributed/_shard/sharded_tensor/_ops/default_tensor_ops.py b/torch/distributed/_shard/sharded_tensor/_ops/default_tensor_ops.py
index 2b8c2a4..8c23f1d 100644
--- a/torch/distributed/_shard/sharded_tensor/_ops/default_tensor_ops.py
+++ b/torch/distributed/_shard/sharded_tensor/_ops/default_tensor_ops.py
@@ -1,6 +1,12 @@
+import copy
import torch
from torch.distributed._shard.sharded_tensor import (
sharded_op_impl,
+ Shard,
+ ShardedTensor,
+)
+from ._common import (
+ _register_sharded_op_on_local_shards,
)
@@ -33,3 +39,58 @@
# __reduce_ex__ to dispatch to get_state/set_state
register_default_op(torch.Tensor.__reduce_ex__)
+
+def sharded_type_as_check(*args, **kwargs):
+ """
+ Perform extra checks for the sharded_type_as op such as the input needs to
+ be either a Tensor or ShardedTensor.
+
+ Args: same as ``torch.Tensor.type_as``.
+
+ Return: None
+ """
+ if len(args) < 2:
+ raise ValueError("Needs to give a tensor to cast type as!")
+ if not isinstance(args[1], torch.Tensor) and not isinstance(args[1], ShardedTensor):
+ raise ValueError("Needs to give a Tensor or ShardedTensor to cast type as!")
+
+
+def same_dtype(*args, **kwargs):
+ """
+ When the dtype is the same, return the original ShardedTensor.
+
+ Args: same as ``torch.Tensor.type_as``.
+
+ Return (bool): Whether to return early or not.
+ """
+ return args[0].dtype == args[1].dtype
+
+
+def sharded_type_as(args, kwargs, pg):
+ """
+ Handles ``__torch_function__`` dispatch for the ``torch.Tensor.type_as`` op.
+
+ Args: same as ``torch.Tensor.type_as``.
+
+ Return:
+ new_local_shards (List[Shard]): Local shards for the new sharded tensor.
+ st_meta (ShardedTensorMetadata): Metadata of the new sharded tensor.
+ """
+ st = args[0]
+ tensor = args[1]
+ if isinstance(tensor, ShardedTensor):
+ tensor = tensor.local_tensor()
+ new_local_shards = []
+ for shard in st.local_shards():
+ new_local_shards.append(Shard(shard.tensor.type_as(tensor), shard.metadata))
+ st_meta = copy.deepcopy(st._metadata)
+ st_meta.tensor_properties.dtype = tensor.dtype
+ return new_local_shards, st_meta
+
+
+_register_sharded_op_on_local_shards(
+ torch.Tensor.type_as,
+ early_stop_func=same_dtype,
+ extra_check=sharded_type_as_check,
+ customized_func=sharded_type_as,
+)
diff --git a/torch/distributed/_shard/sharded_tensor/_ops/elementwise_ops.py b/torch/distributed/_shard/sharded_tensor/_ops/elementwise_ops.py
index a53eebc..dc65277 100644
--- a/torch/distributed/_shard/sharded_tensor/_ops/elementwise_ops.py
+++ b/torch/distributed/_shard/sharded_tensor/_ops/elementwise_ops.py
@@ -4,7 +4,6 @@
_register_sharded_op_on_local_shards,
)
-
_register_sharded_op_on_local_shards(torch.nn.functional.gelu)
_register_sharded_op_on_local_shards(torch.nn.functional.relu)
_register_sharded_op_on_local_shards(torch.nn.functional.dropout)
diff --git a/torch/distributed/_shard/sharded_tensor/_ops/math_ops.py b/torch/distributed/_shard/sharded_tensor/_ops/math_ops.py
index 205f016..6d3ed59 100644
--- a/torch/distributed/_shard/sharded_tensor/_ops/math_ops.py
+++ b/torch/distributed/_shard/sharded_tensor/_ops/math_ops.py
@@ -4,102 +4,80 @@
ShardedTensor,
sharded_op_impl
)
-from torch.distributed._shard.sharding_spec import ChunkShardingSpec
from torch.distributed._shard.replicated_tensor import ReplicatedTensor
from torch.distributed._shard._utils import narrow_tensor
-from ._common import (
- _chunk_sharding_spec_check,
- _register_sharded_op_on_local_tensor,
-)
+def binary_math_op_impl(op, types, args=(), kwargs=None, pg=None):
+ """
+ Handles ``__torch_function__`` dispatch for the binary math ops
+ such as `torch.add`, `torch.mul`, `torch.div`, etc.
+ This method computes on ShardedTensor, or ShardedTensor op ReplicatedTensor
+ """
+ if len(args) != 2:
+ raise ValueError("Only support binary math op on ShardedTensor for now!")
+ lhs = args[0]
+ rhs = args[1]
+ # Validate types
+ if isinstance(lhs, ReplicatedTensor):
+ assert isinstance(rhs, ShardedTensor)
+ st_size = rhs.size()
+ st_meta = rhs.local_shards()[0].metadata
+ if st_size != lhs.size():
+ # try to broadcast replicated tensor
+ lhs = lhs.expand(st_size)
+ replica_part = narrow_tensor(lhs, st_meta)
+ res = op(replica_part, rhs.local_tensor())
+
+ return ShardedTensor._init_from_local_tensor(
+ res,
+ rhs.sharding_spec(),
+ rhs.size(), # type: ignore[arg-type]
+ process_group=pg)
+
+ elif isinstance(rhs, ReplicatedTensor):
+ assert isinstance(lhs, ShardedTensor)
+ st_size = lhs.size()
+ st_meta = lhs.local_shards()[0].metadata
+ if st_size != rhs.size():
+ # try to broadcast replicated tensor
+ rhs = rhs.expand(st_size)
+
+ replica_part = narrow_tensor(rhs, st_meta)
+ res = op(lhs.local_tensor(), replica_part)
+ return ShardedTensor._init_from_local_tensor(
+ res,
+ lhs.sharding_spec(),
+ lhs.size(), # type: ignore[arg-type]
+ process_group=pg)
+
+ elif isinstance(lhs, (int, float)):
+ assert isinstance(rhs, ShardedTensor)
+ res = op(lhs, rhs.local_tensor())
+ return ShardedTensor._init_from_local_tensor(
+ res,
+ rhs.sharding_spec(),
+ rhs.size(), # type: ignore[arg-type]
+ process_group=pg)
+
+ elif isinstance(rhs, (int, float)):
+ assert isinstance(lhs, ShardedTensor)
+ res = op(lhs.local_tensor(), rhs)
+ return ShardedTensor._init_from_local_tensor(
+ res,
+ lhs.sharding_spec(),
+ lhs.size(), # type: ignore[arg-type]
+ process_group=pg)
+ else:
+ raise RuntimeError(
+ f"torch function '{op.__name__}', with args: {args} and "
+ f"kwargs: {kwargs} not supported yet for ShardedTensor!")
def register_math_op(op):
@sharded_op_impl(op)
def binary_math_op(types, args=(), kwargs=None, pg=None):
- """
- Handles ``__torch_function__`` dispatch for the binary math ops
- such as `torch.add`, `torch.mul`, `torch.div`, etc.
- This method computes on ShardedTensor, or ShardedTensor op ReplicatedTensor
- """
- if len(args) != 2:
- raise ValueError("Only support binary math op on ShardedTensor for now!")
- lhs = args[0]
- rhs = args[1]
- # Validate types
- if isinstance(lhs, ShardedTensor) and isinstance(rhs, ShardedTensor):
- lhs_spec = lhs.sharding_spec()
- rhs_spec = rhs.sharding_spec()
- if not isinstance(lhs_spec, ChunkShardingSpec) or not isinstance(rhs_spec, ChunkShardingSpec):
- raise TypeError("Only ShardedTensor with ChunkShardingSpec supports"
- " two ShardedTensor together")
-
- if lhs.size() == rhs.size() and lhs_spec.dim == rhs_spec.dim:
- # perform local element-wise math op
- res = op(lhs.local_tensor(), rhs.local_tensor())
- return ShardedTensor._init_from_local_tensor(
- res,
- lhs_spec,
- lhs.size(), # type: ignore[arg-type]
- process_group=pg)
- else:
- raise RuntimeError("Implicit broadcasting not supported yet!")
-
- elif isinstance(lhs, ReplicatedTensor):
- assert isinstance(rhs, ShardedTensor)
- st_size = rhs.size()
- st_meta = rhs.local_shards()[0].metadata
- if st_size != lhs.size():
- # try to broadcast replicated tensor
- lhs = lhs.expand(st_size)
-
- replica_part = narrow_tensor(lhs, st_meta)
- res = op(replica_part, rhs.local_tensor())
-
- return ShardedTensor._init_from_local_tensor(
- res,
- rhs.sharding_spec(),
- rhs.size(), # type: ignore[arg-type]
- process_group=pg)
-
- elif isinstance(rhs, ReplicatedTensor):
- assert isinstance(lhs, ShardedTensor)
- st_size = lhs.size()
- st_meta = lhs.local_shards()[0].metadata
- if st_size != rhs.size():
- # try to broadcast replicated tensor
- rhs = rhs.expand(st_size)
-
- replica_part = narrow_tensor(rhs, st_meta)
- res = op(lhs.local_tensor(), replica_part)
- return ShardedTensor._init_from_local_tensor(
- res,
- lhs.sharding_spec(),
- lhs.size(), # type: ignore[arg-type]
- process_group=pg)
-
- elif isinstance(lhs, (int, float)):
- assert isinstance(rhs, ShardedTensor)
- res = op(lhs, rhs.local_tensor())
- return ShardedTensor._init_from_local_tensor(
- res,
- rhs.sharding_spec(),
- rhs.size(), # type: ignore[arg-type]
- process_group=pg)
-
- elif isinstance(rhs, (int, float)):
- assert isinstance(lhs, ShardedTensor)
- res = op(lhs.local_tensor(), rhs)
- return ShardedTensor._init_from_local_tensor(
- res,
- lhs.sharding_spec(),
- lhs.size(), # type: ignore[arg-type]
- process_group=pg)
- else:
- raise RuntimeError(
- f"torch function '{op.__name__}', with args: {args} and "
- f"kwargs: {kwargs} not supported yet for ShardedTensor!")
+ return binary_math_op_impl(op, types, args, kwargs, pg)
binary_ops = [
# add
@@ -126,70 +104,3 @@
for op in binary_ops:
register_math_op(op)
-
-
-def sharded_bmm_check(*args, **kwargs):
- """
- Perform extra checks for the sharded_bmm op, for example, st2 needs to
- be a sharded tensor and both tensors need to sharded by dim 0, etc.
-
- Args: same as ``torch.bmm``.
-
- Return: None
- """
- if len(args) < 2:
- raise TypeError("Needs two tensors to perform torch.bmm.")
- st = args[0]
- st2 = args[1]
- # Validate types
- if not isinstance(st2, ShardedTensor):
- raise TypeError("st2 needs to be a ShardedTensor for torch.bmm.")
- _chunk_sharding_spec_check(st2.sharding_spec(), torch.bmm)
- if st.dim() != 3 or st2.dim() != 3:
- raise TypeError("both st and st2 need to be a 3D ShardedTensor")
- if (
- st.sharding_spec().dim != st2.sharding_spec().dim # type: ignore[attr-defined]
- or st.sharding_spec().dim != 0
- ):
- raise NotImplementedError(
- "Only support performing bmm on tensors sharded on dim 0 now."
- )
- if st.sharding_spec().placements != st2.sharding_spec().placements: # type: ignore[attr-defined]
- raise NotImplementedError(
- "Both st and st2 need to have same placements for bmm."
- )
-
-
-def sharded_bmm(args, kwargs, pg):
- """
- Handles ``__torch_function__`` dispatch for the sharded_bmm op.
-
- Warning: For now we only supports the case when both tensors are sharded
- by dim 0 so that no communication is needed.
-
- Args: same as ``torch.bmm``.
-
- Return:
- local_tensor (Tensor): New local tensor to build the sharded tensor.
- sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`):
- sharding spec of the new sharded tensor.
- new_st_size (torch.Size): Size of the new sharded tensor.
- """
- st = args[0]
- st2 = args[1]
- local_tensor = torch.bmm(st.local_tensor(), st2.local_tensor())
- new_st_size = (*st.size()[:-1], st2.size(-1))
- return local_tensor, st.sharding_spec(), new_st_size
-
-
-_register_sharded_op_on_local_tensor(
- torch.Tensor.bmm,
- extra_check=sharded_bmm_check,
- customized_func=sharded_bmm,
-)
-
-_register_sharded_op_on_local_tensor(
- torch.bmm,
- extra_check=sharded_bmm_check,
- customized_func=sharded_bmm,
-)
diff --git a/torch/distributed/_shard/sharded_tensor/api.py b/torch/distributed/_shard/sharded_tensor/api.py
index e66526a..283dfb7 100644
--- a/torch/distributed/_shard/sharded_tensor/api.py
+++ b/torch/distributed/_shard/sharded_tensor/api.py
@@ -17,6 +17,10 @@
from torch.distributed import rpc
from torch.distributed import distributed_c10d
import torch.distributed._shard.sharding_spec as shard_spec
+from torch.distributed._shard.sharding_spec.api import (
+ _dispatch_custom_op,
+ _has_custom_op,
+)
from torch.distributed._shard.sharding_spec._internals import (
check_tensor,
validate_non_overlapping_shards_metadata,
@@ -39,7 +43,7 @@
_sharded_tensor_map: Dict[int, 'weakref.ReferenceType[ShardedTensor]'] = {}
# Custom sharded ops
-_SHARDED_OPS: Dict[str, Callable] = {}
+_SHARDED_OPS: Dict[Callable, Callable] = {}
def _register_sharded_op(op, func):
from inspect import signature
if len(signature(func).parameters) != 4:
@@ -524,6 +528,7 @@
sharded_tensor_metadata,
process_group=process_group,
init_rrefs=init_rrefs,
+ sharding_spec=sharding_spec,
)
@classmethod
@@ -533,6 +538,7 @@
sharded_tensor_metadata: ShardedTensorMetadata,
process_group=None,
init_rrefs=False,
+ sharding_spec=None,
) -> "ShardedTensor":
"""
Initialize a ShardedTensor with local shards and a global
@@ -615,7 +621,10 @@
# done validation, add local_shards
sharded_tensor._local_shards = local_shards
- sharded_tensor._sharding_spec = shard_spec._infer_sharding_spec_from_shards_metadata(shards_metadata)
+ if sharding_spec is None:
+ sharded_tensor._sharding_spec = shard_spec._infer_sharding_spec_from_shards_metadata(shards_metadata)
+ else:
+ sharded_tensor._sharding_spec = sharding_spec
# run post initialization, i.e. map registration, rpc initialization
sharded_tensor._post_init()
@@ -741,15 +750,26 @@
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
- if func in _SHARDED_OPS:
- # Find ShardedTensor instance to get process_group.
- for arg in args:
- if isinstance(arg, ShardedTensor):
- return _SHARDED_OPS[func](types, args, kwargs, arg._process_group)
+ def dispatch(st: ShardedTensor, func: Callable):
+ # Dispatch to custom sharding spec op if it has one.
+ if _has_custom_op(st._sharding_spec, func):
+ return _dispatch_custom_op(st._sharding_spec, func, types, args, kwargs)
- for kwarg in kwargs.values():
- if isinstance(kwarg, ShardedTensor):
- return _SHARDED_OPS[func](types, args, kwargs, kwarg._process_group)
+ if func in _SHARDED_OPS:
+ return _SHARDED_OPS[func](types, args, kwargs, st._process_group)
+
+ raise RuntimeError(
+ f"torch function '{func.__name__}', with args: {args} and "
+ f"kwargs: {kwargs} not supported for ShardedTensor!")
+
+ # Find ShardedTensor instance to get process_group and sharding_spec.
+ for arg in args:
+ if isinstance(arg, ShardedTensor):
+ return dispatch(arg, func)
+
+ for kwarg in kwargs.values():
+ if isinstance(kwarg, ShardedTensor):
+ return dispatch(kwarg, func)
raise RuntimeError(
f"torch function '{func.__name__}', with args: {args} and "
@@ -994,6 +1014,12 @@
def __rtruediv__(self, other):
return handle_torch_function(torch.Tensor.__rdiv__, (self, other), self, other)
+ def tanh(self):
+ return handle_torch_function(torch.Tensor.tanh, (self,), self)
+
+ def __getitem__(self, key):
+ return handle_torch_function(torch.Tensor.__getitem__, (self, key), self, key)
+
@dataclass
class ProcessGroupState:
"""
diff --git a/torch/distributed/_shard/sharding_spec/__init__.py b/torch/distributed/_shard/sharding_spec/__init__.py
index 6f2e5c7..e356295 100644
--- a/torch/distributed/_shard/sharding_spec/__init__.py
+++ b/torch/distributed/_shard/sharding_spec/__init__.py
@@ -1,10 +1,12 @@
from .api import (
- ChunkShardingSpec,
DevicePlacementSpec,
EnumerableShardingSpec,
PlacementSpec,
ShardingSpec,
_infer_sharding_spec_from_shards_metadata,
)
+from .chunk_sharding_spec import (
+ ChunkShardingSpec,
+)
from torch.distributed._shard.metadata import ShardMetadata
diff --git a/torch/distributed/_shard/sharding_spec/api.py b/torch/distributed/_shard/sharding_spec/api.py
index a42add7..32c07cd 100644
--- a/torch/distributed/_shard/sharding_spec/api.py
+++ b/torch/distributed/_shard/sharding_spec/api.py
@@ -1,7 +1,7 @@
from abc import ABC, abstractmethod
from dataclasses import dataclass
-from typing import List, Union
-from typing import TYPE_CHECKING
+import functools
+from typing import Callable, Dict, List, Type, TYPE_CHECKING
import torch
@@ -13,14 +13,7 @@
)
from torch.distributed._shard.metadata import ShardMetadata
-from torch.distributed._shard.sharded_tensor.utils import (
- _parse_and_validate_remote_device
-)
-
-import torch.distributed as dist
import torch.distributed._shard.sharded_tensor.metadata as sharded_tensor_meta
-from torch.distributed._shard.sharded_tensor.shard import Shard
-from torch.distributed._shard._utils import narrow_tensor
if TYPE_CHECKING:
# Only include ShardedTensor when do type checking, exclude it
@@ -51,7 +44,7 @@
if not isinstance(self.device, torch.distributed._remote_device):
self.device = torch.distributed._remote_device(self.device)
-class ShardingSpec(object):
+class ShardingSpec(ABC):
"""
Base class representing sharding specifications.
"""
@@ -92,174 +85,63 @@
A :class:`ShardedTensor` sharded from the given tensor.
"""
-@dataclass
-class ChunkShardingSpec(ShardingSpec):
+# Ops customized for a particular ShardingSpec.
+CUSTOM_SHARDING_SPEC_OPS: Dict[str, Dict[Callable, Callable]] = {}
+
+def _register_custom_op(sharding_spec_cls: Type, op: Callable, func: Callable):
"""
- This is a type of PlacementSpec that defines the placement as being sharded
- across multiple devices. In particular, it represents sharding a Tensor
- along a single dimension into equal chunks (similar to :meth:`torch.chunk`).
-
- The semantics of how a tensor is partitioned is inline with
- :meth:`torch.chunk`, where ``dim`` in torch.chunk corresponds to the
- specified ``dim`` and ``chunks`` in torch.chunk is the number of elements
- in the placement specified.
-
+ Allows registration of a custom op for ShardedTensor to enable
+ custom optimizations for a particular ShardingSpec.
Args:
- dim (int or str):
- The dimension to shard on, could be an integer representing the
- dimension or a string in case of named tensors where dimensions are
- named. Note that named tensor support is not added yet.
- placement(List[Union[_remote_device, str]]):
- Specifies the placement of each shard of the Tensor. The size of
- the list represents the number of shards to be created. This could
- be a list of
- :class:`torch.distributed._remote_device`'s. This list
- could also contain a string which represents remote
- device as accepted by
- :class:`torch.distributed._remote_device`
+ sharding_spec(type): The ShardingSpec for which we need to add this custom op.
+ op(Callable): The op to override (ex: torch.bmm)
+ func(Callable): The custom implementation for ``op``
"""
+ from inspect import signature
+ if len(signature(func).parameters) != 3:
+ raise TypeError(
+ f'Custom sharded op function expects signature: '
+ f'(types, args, kwargs), but received '
+ f'signature: {signature(func)}')
- ShardingDim = Union[int, str]
+ global CUSTOM_SHARDING_SPEC_OPS
+ class_name = sharding_spec_cls.__qualname__
+ if class_name not in CUSTOM_SHARDING_SPEC_OPS:
+ CUSTOM_SHARDING_SPEC_OPS[class_name] = {}
+ CUSTOM_SHARDING_SPEC_OPS[class_name][op] = func
- dim: ShardingDim
- placements: List[Union[torch.distributed._remote_device, str]]
+def _has_custom_op(sharding_spec, op):
+ """
+ Returns whether or not the ShardingSpec has a custom op implementation.
+ """
+ class_name = type(sharding_spec).__qualname__
+ return class_name in CUSTOM_SHARDING_SPEC_OPS and op in CUSTOM_SHARDING_SPEC_OPS[class_name]
- def __post_init__(self):
- self._verify_dim(self.dim)
- for i, remote_device in enumerate(self.placements):
- if not isinstance(remote_device, torch.distributed._remote_device):
- self.placements[i] = torch.distributed._remote_device(remote_device)
+def _dispatch_custom_op(sharding_spec, op: Callable, types, args, kwargs):
+ """
+ Calls the custom op for this ShardingSpec if it exists.
+ """
+ class_name = type(sharding_spec).__qualname__
+ if not _has_custom_op(sharding_spec, op):
+ raise RuntimeError(f'Custom op: {op} not registered for {class_name}')
+ func = CUSTOM_SHARDING_SPEC_OPS[class_name][op]
+ return func(types, args, kwargs)
- @staticmethod
- def _verify_dim(dim):
- # Validate the sharding spec.
- # TODO: support named dimension
- if isinstance(dim, str):
- raise NotImplementedError(
- "ChunkShardingSpec does not support named dimension yet!"
- )
+def custom_sharding_spec_op(sharding_spec_class, func):
+ """
+ Decorator to allow custom registration of ops.
+ Args:
+ sharding_spec_class(type): The ShardingSpec for which we need to add this custom op.
+ func(Callable): The op to override (ex: torch.bmm)
+ """
+ def decorator_sharded_func(wrapped_func):
+ _register_custom_op(sharding_spec_class, func, wrapped_func)
- if not isinstance(dim, int):
- raise ValueError(
- f"Sharding dim needs to be an integer, found: {dim}"
- )
-
- def build_metadata(self,
- tensor_sizes: torch.Size,
- tensor_properties: sharded_tensor_meta.TensorProperties,
- ) -> sharded_tensor_meta.ShardedTensorMetadata:
- tensor_num_dim = len(tensor_sizes)
-
- self._verify_dim(self.dim)
- if self.dim >= tensor_num_dim or self.dim < -tensor_num_dim: # type: ignore[operator]
- raise ValueError(f"Invalid sharding dim: {self.dim}")
-
- shards_metadata = []
- sharding_dim_size = tensor_sizes[self.dim] # type: ignore[index]
- chunks = len(self.placements)
- split_size = get_split_size(sharding_dim_size, chunks)
- for idx, placement in enumerate(self.placements):
- # generate ShardMetadata for each placement device
- chunked_dim_size = get_chunked_dim_size(sharding_dim_size, split_size, idx)
- if chunked_dim_size > 0:
- shard_size = list(tensor_sizes)
- current_offsets = [0] * tensor_num_dim
- current_offsets[self.dim] = split_size * idx # type: ignore[index]
- shard_size[self.dim] = chunked_dim_size # type: ignore[index]
-
- shard_metadata = ShardMetadata(
- shard_offsets=current_offsets,
- shard_sizes=shard_size,
- placement=placement,
- )
- shards_metadata.append(shard_metadata)
-
- return sharded_tensor_meta.ShardedTensorMetadata(
- shards_metadata,
- tensor_sizes,
- tensor_properties
- )
-
-
- def shard(self, tensor: torch.Tensor, src_rank: int = 0, process_group=None) -> "ShardedTensor":
- # relative imports to avoid circular dependency
- from torch.distributed._shard.sharded_tensor import (
- ShardedTensor
- )
- tensor_properties = sharded_tensor_meta.TensorProperties(
- dtype=tensor.dtype,
- layout=tensor.layout,
- requires_grad=tensor.requires_grad,
- memory_format=torch.contiguous_format,
- pin_memory=tensor.is_pinned()
- )
- current_rank = dist.get_rank(process_group)
- tensor_meta = self.build_metadata(tensor.size(), tensor_properties)
- local_shards = []
- local_tensor = None
- local_metadata = None
- tensors_to_scatter = [None] * dist.get_world_size(process_group)
-
- sharding_dim_size = tensor.size()[self.dim] # type: ignore[index]
- chunks = len(self.placements)
- split_size = get_split_size(sharding_dim_size, chunks)
- scatter_shape = list(tensor.size())
- scatter_shape[self.dim] = split_size # type: ignore[index]
-
- for shard_meta in tensor_meta.shards_metadata:
- rank, device = _parse_and_validate_remote_device(process_group, shard_meta.placement)
- if current_rank == src_rank:
- # Reshape to get shard for this rank and we don't want autograd
- # recording here for the narrow op and 'local_shard' should be a
- # leaf variable in the autograd graph.
- narrowed_tensor = narrow_tensor(tensor, shard_meta)
- if shard_meta.shard_sizes[self.dim] < split_size: # type: ignore[index]
- # for the last shard that might be smaller to other shards
- # resize the narrowed tensor to the same size and use it for
- # the scatter collective as dist.scatter requires same size
- # inputs on every rank
- tensor_to_scatter = narrowed_tensor.detach().clone().resize_(scatter_shape)
- else:
- tensor_to_scatter = narrowed_tensor.detach().clone().contiguous()
-
- tensors_to_scatter[rank] = tensor_to_scatter
-
- if current_rank == rank:
- local_tensor = torch.empty(
- scatter_shape, dtype=tensor.dtype, layout=tensor.layout, device=device)
- local_metadata = shard_meta
-
- # each rank should have local_tensor and local_metadata initialized if we build
- # the metadata list in a correct way.
- assert local_tensor is not None
- assert local_metadata is not None
-
- # Scatter the shards to all ranks in the pg
- dist.scatter(
- local_tensor,
- scatter_list=tensors_to_scatter if current_rank == src_rank else None,
- src=src_rank,
- group=process_group
- )
-
- if list(local_tensor.size()) != local_metadata.shard_sizes:
- # detach again after receiving to ensure local shards remain a leaf node
- local_tensor = local_tensor.resize_(local_metadata.shard_sizes).detach()
-
- # Sync requires_grad to local_shard.
- local_tensor.requires_grad = tensor.requires_grad
-
- local_shards.append(Shard(tensor=local_tensor, metadata=local_metadata))
-
- st = ShardedTensor._init_from_local_shards_and_global_metadata(
- local_shards,
- tensor_meta,
- process_group=process_group)
-
- # Manually set sharding_spec
- st._sharding_spec = self
-
- return st
+ @functools.wraps(wrapped_func)
+ def wrapper(*args, **kwargs):
+ return wrapped_func(*args, **kwargs)
+ return wrapper
+ return decorator_sharded_func
@dataclass
@@ -353,6 +235,8 @@
placements = [
x for _, x in sorted(zip(chunk_offset_list, placements), key=lambda e: e[0])
]
+
+ from .chunk_sharding_spec import ChunkShardingSpec
chunk_spec = ChunkShardingSpec(
dim=chunk_sharding_dim,
placements=placements,
diff --git a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py
new file mode 100644
index 0000000..479eea2
--- /dev/null
+++ b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py
@@ -0,0 +1,193 @@
+from dataclasses import dataclass
+import torch
+import torch.distributed._shard.sharded_tensor.metadata as sharded_tensor_meta
+from torch.distributed._shard.metadata import ShardMetadata
+from torch.distributed._shard.sharded_tensor.shard import Shard
+from torch.distributed._shard.sharded_tensor.utils import (
+ _parse_and_validate_remote_device
+)
+from torch.distributed._shard._utils import narrow_tensor
+import torch.distributed as dist
+from typing import List, Union, TYPE_CHECKING
+from ._internals import (
+ get_chunked_dim_size,
+ get_split_size,
+)
+
+from .api import ShardingSpec
+
+if TYPE_CHECKING:
+ # Only include ShardedTensor when do type checking, exclude it
+ # from run-time to resolve circular dependency.
+ from torch.distributed._shard.sharded_tensor import ShardedTensor
+
+@dataclass
+class ChunkShardingSpec(ShardingSpec):
+ """
+ This is a type of PlacementSpec that defines the placement as being sharded
+ across multiple devices. In particular, it represents sharding a Tensor
+ along a single dimension into equal chunks (similar to :meth:`torch.chunk`).
+
+ The semantics of how a tensor is partitioned is inline with
+ :meth:`torch.chunk`, where ``dim`` in torch.chunk corresponds to the
+ specified ``dim`` and ``chunks`` in torch.chunk is the number of elements
+ in the placement specified.
+
+ Args:
+ dim (int or str):
+ The dimension to shard on, could be an integer representing the
+ dimension or a string in case of named tensors where dimensions are
+ named. Note that named tensor support is not added yet.
+ placement(List[Union[_remote_device, str]]):
+ Specifies the placement of each shard of the Tensor. The size of
+ the list represents the number of shards to be created. This could
+ be a list of
+ :class:`torch.distributed._remote_device`'s. This list
+ could also contain a string which represents remote
+ device as accepted by
+ :class:`torch.distributed._remote_device`
+ """
+
+ ShardingDim = Union[int, str]
+
+ dim: ShardingDim
+ placements: List[Union[torch.distributed._remote_device, str]]
+
+ def __post_init__(self):
+ self._verify_dim(self.dim)
+ for i, remote_device in enumerate(self.placements):
+ if not isinstance(remote_device, torch.distributed._remote_device):
+ self.placements[i] = torch.distributed._remote_device(remote_device)
+
+ @staticmethod
+ def _verify_dim(dim):
+ # Validate the sharding spec.
+ # TODO: support named dimension
+ if isinstance(dim, str):
+ raise NotImplementedError(
+ "ChunkShardingSpec does not support named dimension yet!"
+ )
+
+ if not isinstance(dim, int):
+ raise ValueError(
+ f"Sharding dim needs to be an integer, found: {dim}"
+ )
+
+ def build_metadata(self,
+ tensor_sizes: torch.Size,
+ tensor_properties: sharded_tensor_meta.TensorProperties,
+ ) -> sharded_tensor_meta.ShardedTensorMetadata:
+ tensor_num_dim = len(tensor_sizes)
+
+ self._verify_dim(self.dim)
+ if self.dim >= tensor_num_dim or self.dim < -tensor_num_dim: # type: ignore[operator]
+ raise ValueError(f"Invalid sharding dim: {self.dim}")
+
+ shards_metadata = []
+ sharding_dim_size = tensor_sizes[self.dim] # type: ignore[index]
+ chunks = len(self.placements)
+ split_size = get_split_size(sharding_dim_size, chunks)
+ for idx, placement in enumerate(self.placements):
+ # generate ShardMetadata for each placement device
+ chunked_dim_size = get_chunked_dim_size(sharding_dim_size, split_size, idx)
+ if chunked_dim_size > 0:
+ shard_size = list(tensor_sizes)
+ current_offsets = [0] * tensor_num_dim
+ current_offsets[self.dim] = split_size * idx # type: ignore[index]
+ shard_size[self.dim] = chunked_dim_size # type: ignore[index]
+
+ shard_metadata = ShardMetadata(
+ shard_offsets=current_offsets,
+ shard_sizes=shard_size,
+ placement=placement,
+ )
+ shards_metadata.append(shard_metadata)
+
+ # current_offsets[self.dim] += chunked_dim_size # type: ignore[index]
+
+ return sharded_tensor_meta.ShardedTensorMetadata(
+ shards_metadata,
+ tensor_sizes,
+ tensor_properties
+ )
+
+
+ def shard(self, tensor: torch.Tensor, src_rank: int = 0, process_group=None) -> "ShardedTensor":
+ # relative imports to avoid circular dependency
+ from torch.distributed._shard.sharded_tensor import (
+ ShardedTensor
+ )
+ tensor_properties = sharded_tensor_meta.TensorProperties(
+ dtype=tensor.dtype,
+ layout=tensor.layout,
+ requires_grad=tensor.requires_grad,
+ memory_format=torch.contiguous_format,
+ pin_memory=tensor.is_pinned()
+ )
+ current_rank = dist.get_rank(process_group)
+ tensor_meta = self.build_metadata(tensor.size(), tensor_properties)
+ local_shards = []
+ local_tensor = None
+ local_metadata = None
+ tensors_to_scatter = [None] * dist.get_world_size(process_group)
+
+ sharding_dim_size = tensor.size()[self.dim] # type: ignore[index]
+ chunks = len(self.placements)
+ split_size = get_split_size(sharding_dim_size, chunks)
+ scatter_shape = list(tensor.size())
+ scatter_shape[self.dim] = split_size # type: ignore[index]
+
+ for shard_meta in tensor_meta.shards_metadata:
+ rank, device = _parse_and_validate_remote_device(process_group, shard_meta.placement)
+ if current_rank == src_rank:
+ # Reshape to get shard for this rank and we don't want autograd
+ # recording here for the narrow op and 'local_shard' should be a
+ # leaf variable in the autograd graph.
+ narrowed_tensor = narrow_tensor(tensor, shard_meta)
+ if shard_meta.shard_sizes[self.dim] < split_size: # type: ignore[index]
+ # for the last shard that might be smaller to other shards
+ # resize the narrowed tensor to the same size and use it for
+ # the scatter collective as dist.scatter requires same size
+ # inputs on every rank
+ tensor_to_scatter = narrowed_tensor.detach().clone().resize_(scatter_shape)
+ else:
+ tensor_to_scatter = narrowed_tensor.detach().clone().contiguous()
+
+ tensors_to_scatter[rank] = tensor_to_scatter
+
+ if current_rank == rank:
+ local_tensor = torch.empty(
+ scatter_shape, dtype=tensor.dtype, layout=tensor.layout, device=device)
+ local_metadata = shard_meta
+
+ # each rank should have local_tensor and local_metadata initialized if we build
+ # the metadata list in a correct way.
+ assert local_tensor is not None
+ assert local_metadata is not None
+
+ # Scatter the shards to all ranks in the pg
+ dist.scatter(
+ local_tensor,
+ scatter_list=tensors_to_scatter if current_rank == src_rank else None,
+ src=src_rank,
+ group=process_group
+ )
+
+ if list(local_tensor.size()) != local_metadata.shard_sizes:
+ # detach again after receiving to ensure local shards remain a leaf node
+ local_tensor = local_tensor.resize_(local_metadata.shard_sizes).detach()
+
+ # Sync requires_grad to local_shard.
+ local_tensor.requires_grad = tensor.requires_grad
+
+ local_shards.append(Shard(tensor=local_tensor, metadata=local_metadata))
+
+ st = ShardedTensor._init_from_local_shards_and_global_metadata(
+ local_shards,
+ tensor_meta,
+ process_group=process_group)
+
+ # Manually set sharding_spec
+ st._sharding_spec = self
+
+ return st
diff --git a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__init__.py b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__init__.py
diff --git a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/_common.py b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/_common.py
new file mode 100644
index 0000000..c2ef8fe
--- /dev/null
+++ b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/_common.py
@@ -0,0 +1,447 @@
+# coding=utf-8
+
+from typing import List
+
+import torch
+import torch.distributed as dist
+import torch.distributed._shard.sharding_spec as shard_spec
+from torch.distributed._shard.sharded_tensor._ops._common import _sharded_op_common
+from torch.distributed._shard.sharded_tensor import (
+ sharded_op_impl,
+ ShardedTensor,
+)
+from torch.distributed._shard.sharding_spec._internals import (
+ get_split_size,
+ get_chunked_dim_size,
+)
+from torch.distributed.nn.functional import (
+ all_gather,
+ all_to_all_single,
+)
+
+
+def _chunk_sharding_spec_check(spec, op):
+ """
+ For the given op implementation check if the sharding spec is ChunkShardingSpec.
+ """
+ if not isinstance(spec, shard_spec.ChunkShardingSpec):
+ raise NotImplementedError(
+ f"Only ChunkShardingSpec supported for '{op.__name__}'."
+ )
+
+def _register_sharded_op_on_local_tensor(
+ op, early_stop_func=None, extra_check=None, customized_func=None
+):
+ """
+ Handles ``__torch_function__`` dispatch for ops which are performed on
+ the single local tensor of the sharded tensor such as op like
+ ``torch.nn.functional.softmax`` or ``torch.Tensor.view``.
+
+ For more complicated ops, a customized func can be used to generate
+ the new local tensor, sharding spec and sharded tensor size.
+
+ Args:
+ op: The op to be registered and applied to all shards of the st.
+ early_stop_func (Callable, optional): the func for early stop.
+ Default: if ``None``, no early stop.
+ extra_check (Callable, optional): the func for extra condition check.
+ Default: if ``None``, no extra check.
+ customized_func (Callable, optional): the func for customized logic
+ to generate the new local tensor, sharding spec and sharded tensor size.
+ Default: if ``None``, we simply lower to the real op call with
+ the single local tensor of the st.
+
+ Return:
+ func (Callable): registered implementation for sharded op for
+ ``__torch_function__`` dispatch.
+ """
+ @sharded_op_impl(op)
+ @_sharded_op_common(op, early_stop_func, extra_check)
+ def sharded_tensor_op_on_local_shards(types, args=(), kwargs=None, pg=None):
+ st = args[0]
+ sharding_spec = st.sharding_spec()
+ _chunk_sharding_spec_check(sharding_spec, op)
+ if len(st.local_shards()) != 1:
+ raise TypeError(
+ f"torch function '{op.__name__}', with args: {args} and "
+ f"kwargs: {kwargs} only supported for single local tensor!"
+ )
+ st_size = st.size()
+ if customized_func:
+ local_tensor, sharding_spec, st_size = customized_func(args, kwargs, pg)
+ else:
+ args = (st.local_tensor(), *args[1:])
+ local_tensor = op(*args, **kwargs)
+ return ShardedTensor._init_from_local_tensor(
+ local_tensor.contiguous(),
+ sharding_spec,
+ st_size, # type: ignore[arg-type]
+ process_group=pg,
+ init_rrefs=st._init_rrefs,
+ )
+
+
+def _handle_col_wise_sharding_base(
+ op_func,
+ col_dim,
+ input,
+ world_size,
+ weight,
+ local_shard,
+ pg,
+ gathered_inputs=None,
+ mode=None,
+ gathered_per_sample_weights=None,
+ gathered_offsets=None,
+ padding_idx=None,
+):
+ """
+ For col-wise sharding of weight, lots of logic are common.
+ So we extract the common logic and put in this function:
+ Step 1. To get input from each rank and
+ Step 2. To perform the op on the concatenated tensor.
+ Step 3. To distribute results to each rank with col rearrangement.
+ Step 4. To concatenate all results from all ranks.
+
+ Args:
+ op_func: operator which is applied to the input tensor.
+ col_dim: dim of result tensor after the operation.
+ input: tensor to be applied op on.
+ world_size: number of ranks.
+ weight: sharded weight tensor.
+ local_shard: col-wise sharded weight tensor.
+ pg: process group.
+ gathered_inputs: list of inputs from all ranks. If specified, we
+ don't need to communicate with each rank any more.
+ mode: aggregation mode of EmbeddingBag.
+ gathered_per_sample_weights: per_sample_weights across all ranks.
+ gathered_offsets: offsets across all ranks.
+ padding_idx: If specified, the entries at padding_idx do
+ not contribute to the gradient; therefore, the embedding
+ vector at padding_idx is not updated during training,
+ i.e. it remains as a fixed “pad”.
+ Note that the embedding vector at padding_idx is
+ excluded from the reduction.
+
+ Return: final result of input being applied with the op.
+ """
+ if gathered_inputs is None:
+ # allgather the inputs first.
+ gathered_inputs = all_gather(input, group=pg)
+
+ # run the operator's function for all the inputs.
+ results = []
+ for i, inp in enumerate(gathered_inputs):
+ if op_func == torch.nn.functional.embedding_bag:
+ result = op_func(
+ inp,
+ local_shard,
+ offsets=gathered_offsets[i] if gathered_offsets is not None else None,
+ mode=mode,
+ per_sample_weights=gathered_per_sample_weights[i]
+ if gathered_per_sample_weights is not None
+ else None,
+ padding_idx=padding_idx,
+ )
+ elif op_func == torch.nn.functional.embedding:
+ result = op_func(
+ inp,
+ local_shard,
+ padding_idx=padding_idx,
+ )
+ else:
+ result = op_func(inp, local_shard)
+ results.append(torch.transpose(result, 0, col_dim))
+
+ # Distribute results to each rank with col rearrangement.
+ output = _result_distribute_with_col_rearrange(
+ results, input, world_size, weight, pg
+ )
+
+ # transpose the output and return result.
+ return torch.transpose(output, 0, col_dim)
+
+
+def _result_distribute_with_col_rearrange(
+ results, input, world_size, weight, pg
+):
+ """
+ For col-wise sharding of weight, we need to distribute
+ results to each rank. We do them in this function.
+ Note that, if the index in the Sharding Spec is not equal to
+ the rank number, we need to do the rearrangement based on the
+ order given by the Sharding Spec (placement).
+
+ Args:
+ results: results from ops applied to inputs from all ranks.
+ We need to distribute them back to their original ranks.
+ input: tensor to be applied op to.
+ world_size: number of ranks.
+ weight: sharded weight tensor.
+ pg: process group.
+
+ Return: column rearranged result.
+ """
+ # Process results and outputs for all2all.
+ sharding_dim = weight._sharding_spec.dim
+ sharding_dim_size = weight.size(sharding_dim)
+ dims = list(results[0].size())
+ dims[0] = sharding_dim_size
+ output = torch.empty(*dims, device=input.device)
+ combined_results = torch.cat(results)
+
+ # Compute output splits
+ split_size = get_split_size(sharding_dim_size, world_size)
+ output_split_sizes = [0] * world_size
+ for idx, placement in enumerate(weight._sharding_spec.placements):
+ output_split_sizes[placement.rank()] = get_chunked_dim_size(
+ sharding_dim_size, split_size, idx
+ )
+
+ # distribute the outputs using all2all.
+ output = all_to_all_single(
+ output, combined_results, output_split_sizes=output_split_sizes, group=pg
+ )
+
+ # Check if we need to rearrange columns appropriately for output.
+ rearrange_columns = any(
+ [
+ idx != placement.rank()
+ for idx, placement in enumerate(weight._sharding_spec.placements)
+ ]
+ )
+ if not rearrange_columns:
+ return output
+
+ indices = []
+ for placement in weight._sharding_spec.placements:
+ dim_size = output_split_sizes[placement.rank()]
+ start = sum(
+ [
+ split_size if i < placement.rank() else 0
+ for i, split_size in enumerate(output_split_sizes)
+ ]
+ )
+ indices += list(range(start, start + dim_size))
+
+ return output.index_select(0, torch.tensor(indices, device=output.device))
+
+
+def _handle_row_wise_lookup_distribute(
+ input_sorted, input, world_size, weight, rank, padding_idx
+):
+ """
+    When the weight is row-wise sharded, we need to distribute
+    the sorted lookup IDs of embedding/embeddingBag to each rank.
+    If the index in the placement is not equal to the rank number, we need to
+    do the rearrangement based on the order given by the Sharding Spec (placement).
+
+    In addition, we do two things for padding_idx: we only keep it if it falls
+    within the range owned by the current rank, and we take it modulo
+    sharded_dim_size_max so that it indexes into the local shard.
+
+ Args:
+ input_sorted: sorted lookup IDs of embedding/embeddingBag.
+        input: tensor that the op is applied to.
+        world_size: number of ranks.
+        weight: sharded weight tensor.
+        rank: rank of the current process.
+ padding_idx: If specified, the entries at padding_idx do
+ not contribute to the gradient and reduction.
+
+ Return:
+        input_sorted: sorted lookup IDs of embedding/embeddingBag,
+            rearranged if needed.
+ input_split_sizes: size of IDs to be assigned to each rank.
+ sharded_dim_size_max: the max size of the row each rank gets.
+ input_split_rearrange_indices: indices of row rearrangement.
+ rearrange_indices_1d_second_order: reverse indices of row
+ rearrangement, which will be used to restore the original
+ order.
+        padding_idx: Same as the input padding_idx if it is within the range
+            owned by the given rank; otherwise, None is returned. It is
+            also taken modulo sharded_dim_size_max.
+ """
+    # Decide which rank each input ID goes to by checking the sharding range.
+ split_size = get_split_size(weight.size(0), world_size)
+ rearrange_rows = False
+ indices_flatten = None
+ input_split_sizes: List[int] = [0] * world_size
+ input_split_start_indices: List[int] = [0] * world_size
+ start_row_idx_rank = None
+ end_row_idx_rank = None
+    # When we do the chunk split, the first N - 1 chunks always get the maximum
+    # chunk size and the Nth chunk gets the remainder. So split sizes like
+    # [3, 3, 3, 4] are not possible; for 13 rows over 4 ranks the expected
+    # split is [4, 4, 4, 1].
+ sharded_dim_size_max = get_chunked_dim_size(weight.size(0), split_size, 0)
+ for idx, placement in enumerate(weight._sharding_spec.placements):
+ sharded_dim_size = get_chunked_dim_size(weight.size(0), split_size, idx)
+ start_row_idx = idx * sharded_dim_size_max
+ end_row_idx = start_row_idx + sharded_dim_size
+ start_idx = torch.searchsorted(input_sorted, start_row_idx).item()
+ end_idx = torch.searchsorted(input_sorted, end_row_idx).item()
+ input_split_sizes[placement.rank()] = int(end_idx - start_idx)
+ input_split_start_indices[placement.rank()] = int(start_idx)
+ if placement.rank() != idx:
+ rearrange_rows = True
+ # Store the range of the current rank.
+ if placement.rank() == rank:
+ start_row_idx_rank = start_row_idx
+ end_row_idx_rank = end_row_idx
+
+    # Take padding_idx modulo the max shard size if it falls within this rank's range.
+ if padding_idx is not None:
+ if padding_idx < start_row_idx_rank or padding_idx >= end_row_idx_rank:
+ padding_idx = None
+ else:
+ padding_idx = padding_idx % sharded_dim_size_max
+
+ rearrange_indices_1d_second_order = None
+ if rearrange_rows:
+ # Need to re-arrange the 1D tensor to be sent via all2all.
+ indices: List[List[int]] = [[0]] * world_size
+ for placement in weight._sharding_spec.placements:
+ split_length = input_split_sizes[placement.rank()]
+ offset_idx = input_split_start_indices[placement.rank()]
+ indices[placement.rank()] = list(
+ range(offset_idx, offset_idx + split_length)
+ )
+ indices_flatten = list(idx for indice in indices for idx in indice)
+
+ input_sorted = input_sorted.index_select(
+ 0, torch.tensor(indices_flatten, device=input.device)
+ )
+ rearrange_indices_1d_second_order = torch.argsort(torch.Tensor(indices_flatten))
+
+ return (
+ input_sorted,
+ input_split_sizes,
+ sharded_dim_size_max,
+ torch.tensor(indices_flatten, device=input.device) if rearrange_rows else None,
+ rearrange_indices_1d_second_order,
+ padding_idx,
+ )
+
+
+def _communicate_size_to_each_rank(
+ input_size_list, output_size, input, pg, tensor_type=torch.int
+):
+ """
+    When the weight is row-wise sharded, we first need to communicate
+    the input length to each rank, because each rank receives a
+    different number of lookup IDs.
+
+ Args:
+ input_size_list: list of sizes to be sent to each rank.
+ output_size: length of the output tensor.
+        input: tensor that the op is applied to.
+ pg: process group.
+ tensor_type: dtype of tensor.
+
+ Return: A list of communication results (int).
+ """
+ input_size_list_tensor = torch.tensor(
+ input_size_list, dtype=tensor_type, device=input.device
+ )
+ output_size_list_tensor = torch.empty(
+ output_size, dtype=tensor_type, device=input.device
+ )
+ dist.all_to_all_single(
+ output_size_list_tensor,
+ input_size_list_tensor,
+ group=pg,
+ )
+ return output_size_list_tensor.tolist()
+
+
+def _communicate_list_to_each_rank(
+ input_tensor_list, output_lists, input, pg, tensor_type=torch.int64
+):
+ """
+    When the weight is row-wise sharded, we need to communicate a list
+    of input tensors to each rank. Because the input could be a list of
+    lists, we first need to convert each list to a tensor.
+
+ Args:
+ input_tensor_list: list of tensors to be sent to each rank.
+ output_lists: list of sizes to be obtained from each rank.
+        input: tensor that the op is applied to.
+ pg: process group.
+ tensor_type: dtype of tensor.
+
+ Return: A list of communication results (tensors).
+ """
+ output_tensor_list = []
+ for output_list in output_lists:
+ output_tensor_list.append(
+ torch.empty(output_list, dtype=tensor_type, device=input.device)
+ )
+ dist.all_to_all(
+ output_tensor_list,
+ input_tensor_list,
+ group=pg,
+ )
+ return output_tensor_list
+
+
+def _handle_max_norm_col_wise(
+ max_norm,
+ norm_type,
+ local_shard,
+ input,
+ world_size,
+ pg,
+):
+ """
+ For col-wise sharding of weight, we need to aggregate the
+ norm across all ranks before we can perform the proper re-norm.
+    Note that the max_norm logic is only applied to the embedding
+ indices that are looked up and not the whole shard.
+
+ Args:
+ max_norm: If given, each embedding vector with norm larger
+ than max_norm is renormalized to have norm max_norm.
+ Note: this will modify weight in-place.
+ norm_type: The p in the p-norm to compute for the max_norm option.
+        local_shard: col-wise sharded local weight used for the lookup.
+        input: tensor that the op is applied to.
+ world_size: number of ranks.
+ pg: process group.
+
+ Return:
+        local_shard_norm_renormed: local_shard renormalized so that no looked-up
+            embedding vector has a norm larger than max_norm.
+ gathered_inputs: list of inputs from all ranks.
+ """
+ norm_type = norm_type if norm_type is not None else 2.0
+ # allgather the inputs first.
+ gathered_inputs = [torch.zeros_like(input) for _ in range(world_size)]
+ dist.all_gather(gathered_inputs, input, group=pg)
+ unique_inp = torch.unique(torch.cat(gathered_inputs))
+ local_shard_sum = torch.sum(
+ torch.pow(torch.abs(local_shard), norm_type), dim=1, dtype=local_shard.dtype
+ )
+    # For col-wise sharding, we first aggregate the powered sums
+    # from each rank and then calculate the norm.
+ dist.all_reduce(local_shard_sum, group=pg)
+ local_shard_norm = torch.pow(local_shard_sum, 1.0 / norm_type)
+ max_norm_tensor = torch.full(
+ (local_shard.size(0),),
+ float("inf"),
+ dtype=local_shard.dtype,
+ device=input.device,
+ )
+ max_norm_tensor[unique_inp] = max_norm
+ local_shard_t = local_shard.t().contiguous()
+ normalized_tensor = torch.where(
+ local_shard_norm > max_norm_tensor, max_norm_tensor, local_shard_norm
+ )
+ # Make sure divisor is not zero.
+ local_shard_norm[local_shard_norm == 0.0] = 1.0
+ local_shard_norm_renormed = (
+ torch.div(torch.mul(local_shard_t, normalized_tensor), local_shard_norm)
+ .t()
+ .contiguous()
+ )
+ return local_shard_norm_renormed, gathered_inputs
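The split-size convention that _handle_row_wise_lookup_distribute and _result_distribute_with_col_rearrange rely on can be illustrated without any process group. The sketch below re-implements the chunking rule with small local helpers (the real ones are get_split_size/get_chunked_dim_size from torch.distributed._shard.sharding_spec._internals); the placement order is a hypothetical example, not something taken from this diff.

# Standalone sketch of the "first N - 1 chunks are maxed out" convention
# and of the regrouping of outputs by placement rank. The helpers below are
# local re-implementations for illustration only.

def split_size(dim_size, chunks):
    # Max size of a chunk (ceil division).
    return -(-dim_size // chunks)

def chunked_dim_size(dim_size, split, idx):
    # Size of the idx-th chunk under the max-out-first convention.
    return max(min(dim_size, split * (idx + 1)) - split * idx, 0)

world_size = 4
rows = 13
split = split_size(rows, world_size)                        # 4
print([chunked_dim_size(rows, split, i) for i in range(world_size)])
# [4, 4, 4, 1] -- never [3, 3, 3, 4]

# If the sharding spec places shard 0 on rank 2, shard 1 on rank 0, and so on
# (hypothetical placements), shard index and rank differ, so per-rank output
# sizes must be regrouped by rank, mirroring output_split_sizes above.
placements = [2, 0, 3, 1]
output_split_sizes = [0] * world_size
for idx, rank in enumerate(placements):
    output_split_sizes[rank] = chunked_dim_size(rows, split, idx)
print(output_split_sizes)                                   # [4, 1, 4, 4]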
diff --git a/torch/distributed/_shard/sharded_tensor/_ops/embedding.py b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding.py
similarity index 97%
rename from torch/distributed/_shard/sharded_tensor/_ops/embedding.py
rename to torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding.py
index 11ef153..232e716 100644
--- a/torch/distributed/_shard/sharded_tensor/_ops/embedding.py
+++ b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding.py
@@ -11,13 +11,13 @@
_handle_max_norm_col_wise,
)
from torch.distributed._shard.sharding_spec import ChunkShardingSpec
+from torch.distributed._shard.sharding_spec.api import custom_sharding_spec_op
from torch.distributed._shard.sharded_tensor import (
- sharded_op_impl,
ShardedTensor
)
-@sharded_op_impl(torch.nn.functional.embedding)
-def sharded_embedding(types, args, kwargs, pg):
+@custom_sharding_spec_op(ChunkShardingSpec, torch.nn.functional.embedding)
+def sharded_embedding(types, args, kwargs):
"""
Handles ``__torch_function__`` dispatch for ``torch.nn.functional.embedding``.
This method computes a sharded embedding lookup and has the following limitations:
@@ -104,6 +104,7 @@
norm_type = kwargs.get("norm_type")
padding_idx = kwargs.get("padding_idx")
+ pg = weight._process_group
local_shard = weight.local_tensor().contiguous()
sharding_dim = weight._sharding_spec.dim
world_size = dist.get_world_size(pg)
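The decorator swap in this hunk (sharded_op_impl to custom_sharding_spec_op) is the same registration pattern a user-defined spec could follow. A rough, hypothetical sketch, assuming custom_sharding_spec_op keeps the (spec class, op) decorator form and the (types, args, kwargs) handler signature shown above; MyRingShardingSpec and my_ring_softmax are invented names and the handler body is only a placeholder:

import torch
from torch.distributed._shard.sharding_spec import ShardingSpec
from torch.distributed._shard.sharding_spec.api import custom_sharding_spec_op

class MyRingShardingSpec(ShardingSpec):
    # Hypothetical user-defined spec; build_metadata/shard are elided because
    # only the class identity matters for op registration in this sketch.
    ...

@custom_sharding_spec_op(MyRingShardingSpec, torch.nn.functional.softmax)
def my_ring_softmax(types, args, kwargs):
    # Invoked when softmax is dispatched on a ShardedTensor whose sharding
    # spec is MyRingShardingSpec; the body is a placeholder that simply
    # computes on the local shard.
    st = args[0]
    return torch.nn.functional.softmax(st.local_tensor(), *args[1:], **(kwargs or {}))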
diff --git a/torch/distributed/_shard/sharded_tensor/_ops/embedding_bag.py b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py
similarity index 98%
rename from torch/distributed/_shard/sharded_tensor/_ops/embedding_bag.py
rename to torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py
index 6fad875..d3151a3 100644
--- a/torch/distributed/_shard/sharded_tensor/_ops/embedding_bag.py
+++ b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py
@@ -15,14 +15,14 @@
_handle_max_norm_col_wise,
)
from torch.distributed._shard.sharding_spec import ChunkShardingSpec
+from torch.distributed._shard.sharding_spec.api import custom_sharding_spec_op
from torch.distributed._shard.sharded_tensor import (
- sharded_op_impl,
ShardedTensor
)
-@sharded_op_impl(torch.nn.functional.embedding_bag)
-def sharded_embedding_bag(types, args, kwargs, pg):
+@custom_sharding_spec_op(ChunkShardingSpec, torch.nn.functional.embedding_bag)
+def sharded_embedding_bag(types, args, kwargs):
"""
Handles ``__torch_function__`` dispatch for ``torch.nn.functional.embedding_bag``.
This method computes a sharded embedding bag aggregation and has the following limitations:
@@ -123,6 +123,7 @@
include_last_offset = kwargs.get("include_last_offset")
padding_idx = kwargs.get("padding_idx")
+ pg = weight._process_group
local_shard = weight.local_tensor().contiguous()
sharding_dim = weight._sharding_spec.dim
world_size = dist.get_world_size(pg)
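Both embedding handlers route max_norm through _handle_max_norm_col_wise above. The identity it relies on, that a row's p-norm can be rebuilt from per-shard powered sums, can be checked with plain torch; the all_reduce is simulated here by summing the shard contributions.

import torch

torch.manual_seed(0)
weight = torch.randn(6, 8)               # full embedding table
shards = torch.chunk(weight, 2, dim=1)   # two col-wise shards, 4 columns each
norm_type = 2.0

# Each "rank" computes the powered sum over its own columns only.
powered_sums = [torch.sum(torch.abs(s) ** norm_type, dim=1) for s in shards]

# Simulate dist.all_reduce by summing the per-shard contributions, then take
# the final root, as _handle_max_norm_col_wise does after the all_reduce.
row_norm = torch.stack(powered_sums).sum(dim=0) ** (1.0 / norm_type)

assert torch.allclose(row_norm, weight.norm(p=norm_type, dim=1))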
diff --git a/torch/distributed/_shard/sharded_tensor/_ops/linear.py b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/linear.py
similarity index 98%
rename from torch/distributed/_shard/sharded_tensor/_ops/linear.py
rename to torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/linear.py
index 560ba91..87eefd9 100644
--- a/torch/distributed/_shard/sharded_tensor/_ops/linear.py
+++ b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/linear.py
@@ -9,10 +9,10 @@
)
from torch.distributed._shard.partial_tensor import _PartialTensor
from torch.distributed._shard.sharded_tensor import (
- sharded_op_impl,
ShardedTensor,
)
from torch.distributed._shard.sharding_spec import ChunkShardingSpec
+from torch.distributed._shard.sharding_spec.api import custom_sharding_spec_op
from torch.distributed._shard.sharding_spec._internals import (
get_split_size,
get_chunked_dim_size,
@@ -24,8 +24,8 @@
)
-@sharded_op_impl(torch.nn.functional.linear)
-def sharded_linear(types, args, kwargs, pg):
+@custom_sharding_spec_op(ChunkShardingSpec, torch.nn.functional.linear)
+def sharded_linear(types, args, kwargs):
"""
Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``.
This method computes a sharded linear and has the following limitations:
@@ -99,6 +99,7 @@
weight = args[1]
bias = args[2]
+ pg = weight._process_group
local_shard = weight.local_tensor()
local_shard_t = local_shard.t().contiguous()
sharding_dim = weight._sharding_spec.dim
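The col-wise path of sharded_linear leans on the same fact as the _common.py helpers: splitting the weight along the output-feature dimension splits the linear output column-wise, so each rank can compute its slice locally. A torch-only check of that identity (sizes and world size are arbitrary; this shows the math, not the actual communication pattern):

import torch
import torch.nn.functional as F

torch.manual_seed(0)
x = torch.randn(5, 16)            # (batch, in_features)
weight = torch.randn(12, 16)      # (out_features, in_features)
bias = torch.randn(12)

# Pretend world_size = 3: each "rank" holds a chunk of the output features.
w_shards = torch.chunk(weight, 3, dim=0)
b_shards = torch.chunk(bias, 3, dim=0)

# Each rank produces only its columns of the output; concatenating the
# per-rank results along the last dim recovers the full linear output.
local_outs = [F.linear(x, w, b) for w, b in zip(w_shards, b_shards)]

assert torch.allclose(torch.cat(local_outs, dim=-1), F.linear(x, weight, bias))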
diff --git a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/math_ops.py b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/math_ops.py
new file mode 100644
index 0000000..5f5b59f
--- /dev/null
+++ b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/math_ops.py
@@ -0,0 +1,72 @@
+import torch
+from torch import Tensor
+from torch.distributed._shard.sharded_tensor import (
+ ShardedTensor,
+)
+from torch.distributed._shard.sharding_spec import ChunkShardingSpec
+from torch.distributed._shard.sharding_spec.api import custom_sharding_spec_op
+from torch.distributed._shard.sharded_tensor._ops.math_ops import binary_math_op_impl
+
+from ._common import (
+ _chunk_sharding_spec_check,
+)
+
+def register_math_op(op):
+ @custom_sharding_spec_op(ChunkShardingSpec, op)
+ def binary_math_op(types, args=(), kwargs=None):
+ """
+ Handles ``__torch_function__`` dispatch for the binary math ops
+ such as `torch.add`, `torch.mul`, `torch.div`, etc.
+        This method computes the binary op on ShardedTensors.
+ """
+ if len(args) != 2:
+ raise ValueError("Only support binary math op on ShardedTensor for now!")
+ lhs = args[0]
+ rhs = args[1]
+ pg = lhs._process_group if isinstance(lhs, ShardedTensor) else rhs._process_group
+ # Validate types
+ if isinstance(lhs, ShardedTensor) and isinstance(rhs, ShardedTensor):
+ lhs_spec = lhs.sharding_spec()
+ rhs_spec = rhs.sharding_spec()
+ _chunk_sharding_spec_check(lhs_spec, op)
+ _chunk_sharding_spec_check(rhs_spec, op)
+
+ if lhs.size() == rhs.size() and lhs_spec.dim == rhs_spec.dim: # type: ignore[attr-defined]
+ # perform local element-wise math op
+ res = op(lhs.local_tensor(), rhs.local_tensor())
+ return ShardedTensor._init_from_local_tensor(
+ res,
+ lhs_spec,
+ lhs.size(), # type: ignore[arg-type]
+ process_group=pg)
+ else:
+ raise RuntimeError("Implicit broadcasting not supported yet!")
+ else:
+ # Try dispatch to ShardingSpec agnostic ops.
+ return binary_math_op_impl(op, types, args, kwargs, pg)
+
+binary_ops = [
+ # add
+ torch.add,
+ Tensor.add,
+ Tensor.__add__,
+ Tensor.__radd__,
+ # sub
+ torch.sub,
+ Tensor.sub,
+ Tensor.__sub__,
+ Tensor.__rsub__,
+ # mul
+ torch.mul,
+ Tensor.mul,
+ Tensor.__mul__,
+ Tensor.__rmul__,
+ # div
+ torch.div,
+ Tensor.div,
+ Tensor.__div__,
+ Tensor.__rdiv__,
+]
+
+for op in binary_ops:
+ register_math_op(op)
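The fast path above applies the op directly to the two local shards when the sizes and sharding dims match. That works because an elementwise op commutes with chunking along any dim, which can be checked without a process group:

import torch

torch.manual_seed(0)
a = torch.randn(8, 6)
b = torch.randn(8, 6)
world_size = 4
dim = 0                                   # both tensors "sharded" on dim 0

a_shards = torch.chunk(a, world_size, dim=dim)
b_shards = torch.chunk(b, world_size, dim=dim)

# Elementwise op on each pair of local shards...
local = [torch.add(x, y) for x, y in zip(a_shards, b_shards)]

# ...is exactly the corresponding shard of the full elementwise result.
assert torch.equal(torch.cat(local, dim=dim), torch.add(a, b))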
diff --git a/torch/distributed/_shard/sharded_tensor/_ops/matrix_ops.py b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/matrix_ops.py
similarity index 80%
rename from torch/distributed/_shard/sharded_tensor/_ops/matrix_ops.py
rename to torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/matrix_ops.py
index 4fdf9a3..5bf3fb2 100644
--- a/torch/distributed/_shard/sharded_tensor/_ops/matrix_ops.py
+++ b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/matrix_ops.py
@@ -4,73 +4,14 @@
import torch
import torch.distributed as dist
from torch.distributed._shard.sharded_tensor import (
- Shard,
ShardedTensor,
)
from ._common import (
_chunk_sharding_spec_check,
_register_sharded_op_on_local_tensor,
- _register_sharded_op_on_local_shards,
)
-
-def sharded_type_as_check(*args, **kwargs):
- """
- Perform extra checks for the sharded_type_as op such as the input needs to
- be either a Tensor or ShardedTensor.
-
- Args: same as ``torch.Tensor.type_as``.
-
- Return: None
- """
- if len(args) < 2:
- raise ValueError("Needs to give a tensor to cast type as!")
- if not isinstance(args[1], torch.Tensor) and not isinstance(args[1], ShardedTensor):
- raise ValueError("Needs to give a Tensor or ShardedTensor to cast type as!")
-
-
-def same_dtype(*args, **kwargs):
- """
- When the dtype is the same, return the original ShardedTensor.
-
- Args: same as ``torch.Tensor.type_as``.
-
- Return (bool): Whether to return early or not.
- """
- return args[0].dtype == args[1].dtype
-
-
-def sharded_type_as(args, kwargs, pg):
- """
- Handles ``__torch_function__`` dispatch for the ``torch.Tensor.type_as`` op.
-
- Args: same as ``torch.Tensor.type_as``.
-
- Return:
- new_local_shards (List[Shard]): Local shards for the new sharded tensor.
- st_meta (ShardedTensorMetadata): Metadata of the new sharded tensor.
- """
- st = args[0]
- tensor = args[1]
- if isinstance(tensor, ShardedTensor):
- tensor = tensor.local_tensor()
- new_local_shards = []
- for shard in st.local_shards():
- new_local_shards.append(Shard(shard.tensor.type_as(tensor), shard.metadata))
- st_meta = copy.deepcopy(st._metadata)
- st_meta.tensor_properties.dtype = tensor.dtype
- return new_local_shards, st_meta
-
-
-_register_sharded_op_on_local_shards(
- torch.Tensor.type_as,
- early_stop_func=same_dtype,
- extra_check=sharded_type_as_check,
- customized_func=sharded_type_as,
-)
-
-
def transpose_same_dim(*args, **kwargs):
"""
When the dim0 and dim1 of transpose are the same, return the original ShardedTensor.
@@ -322,3 +263,68 @@
extra_check=sharded_view_check,
customized_func=sharded_view,
)
+
+def sharded_bmm_check(*args, **kwargs):
+ """
+    Perform extra checks for the sharded_bmm op, for example, st2 needs to
+    be a ShardedTensor and both tensors need to be sharded on dim 0, etc.
+
+ Args: same as ``torch.bmm``.
+
+ Return: None
+ """
+ if len(args) < 2:
+ raise TypeError("Needs two tensors to perform torch.bmm.")
+ st = args[0]
+ st2 = args[1]
+ # Validate types
+ if not isinstance(st2, ShardedTensor):
+ raise TypeError("st2 needs to be a ShardedTensor for torch.bmm.")
+ _chunk_sharding_spec_check(st2.sharding_spec(), torch.bmm)
+ if st.dim() != 3 or st2.dim() != 3:
+ raise TypeError("both st and st2 need to be a 3D ShardedTensor")
+ if (
+ st.sharding_spec().dim != st2.sharding_spec().dim # type: ignore[attr-defined]
+ or st.sharding_spec().dim != 0
+ ):
+ raise NotImplementedError(
+ "Only support performing bmm on tensors sharded on dim 0 now."
+ )
+ if st.sharding_spec().placements != st2.sharding_spec().placements: # type: ignore[attr-defined]
+ raise NotImplementedError(
+ "Both st and st2 need to have same placements for bmm."
+ )
+
+def sharded_bmm(args, kwargs, pg):
+ """
+ Handles ``__torch_function__`` dispatch for the sharded_bmm op.
+
+    Warning: For now we only support the case where both tensors are sharded
+    on dim 0, so that no communication is needed.
+
+ Args: same as ``torch.bmm``.
+
+ Return:
+ local_tensor (Tensor): New local tensor to build the sharded tensor.
+ sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`):
+ sharding spec of the new sharded tensor.
+ new_st_size (torch.Size): Size of the new sharded tensor.
+ """
+ st = args[0]
+ st2 = args[1]
+ local_tensor = torch.bmm(st.local_tensor(), st2.local_tensor())
+ new_st_size = (*st.size()[:-1], st2.size(-1))
+ return local_tensor, st.sharding_spec(), new_st_size
+
+
+_register_sharded_op_on_local_tensor(
+ torch.Tensor.bmm,
+ extra_check=sharded_bmm_check,
+ customized_func=sharded_bmm,
+)
+
+_register_sharded_op_on_local_tensor(
+ torch.bmm,
+ extra_check=sharded_bmm_check,
+ customized_func=sharded_bmm,
+)
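The no-communication claim in sharded_bmm follows from bmm being independent across the batch dimension: chunking both operands on dim 0 and running bmm on the local chunks yields exactly the dim 0 chunks of the full result. A quick torch-only check (plain tensors stand in for the local shards of two batch-sharded ShardedTensors):

import torch

torch.manual_seed(0)
lhs = torch.randn(8, 4, 5)
rhs = torch.randn(8, 5, 3)
world_size = 4

lhs_shards = torch.chunk(lhs, world_size, dim=0)
rhs_shards = torch.chunk(rhs, world_size, dim=0)

# bmm on each pair of dim-0 shards...
local = [torch.bmm(l, r) for l, r in zip(lhs_shards, rhs_shards)]

# ...matches the corresponding dim-0 shard of the full bmm, which is why no
# cross-rank communication is needed for this sharding.
assert torch.allclose(torch.cat(local, dim=0), torch.bmm(lhs, rhs))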