Allow for custom sharding specs to register their own ops.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/76360

Customized ShardingSpecs can be entirely arbitrary and might not fit the patterns
supported by the built-in ShardingSpecs, so their ops cannot be handled generically.
To address this, we introduce a framework that allows a ShardingSpec to override
ops as follows:

1) In the dispatch system, if a ShardingSpec has a customized implementation for an
op, the op registered for that ShardingSpec is invoked (see the sketch after this list).
2) As a result, all ChunkShardingSpec-specific ops have been moved under that
ShardingSpec.
3) There is a set of ShardingSpec-agnostic ops (ex: elementwise ops) that are
supported across all ShardingSpecs.
4) If an op is not found for a particular ShardingSpec, the default set of ops
is searched for that op.
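
For illustration, here is a minimal sketch of how a custom ShardingSpec could hook into
this framework via the `custom_sharding_spec_op` decorator added in
`torch/distributed/_shard/sharding_spec/api.py`. `MyShardingSpec`, `my_gelu`, and the
shard-by-shard body are hypothetical, and the abstract method names are assumed to mirror
`ChunkShardingSpec`; the contract this PR enforces is that the registered function
accepts `(types, args, kwargs)`.

```python
import torch
from torch.distributed._shard.sharding_spec.api import (
    ShardingSpec,
    custom_sharding_spec_op,
)

class MyShardingSpec(ShardingSpec):
    """Hypothetical custom spec; build_metadata/shard bodies elided for brevity."""
    def build_metadata(self, tensor_sizes, tensor_properties):
        raise NotImplementedError

    def shard(self, tensor, src_rank=0, process_group=None):
        raise NotImplementedError

# Register a custom gelu for ShardedTensors sharded with MyShardingSpec.
# The dispatcher checks for a spec-specific op before falling back to the
# default sharded op table.
@custom_sharding_spec_op(MyShardingSpec, torch.nn.functional.gelu)
def my_gelu(types, args, kwargs):
    kwargs = kwargs or {}
    st = args[0]  # the ShardedTensor located by __torch_function__
    # Sketch only: apply the op shard by shard. A real implementation would
    # rebuild and return a ShardedTensor from these results.
    return [torch.nn.functional.gelu(shard.tensor, **kwargs) for shard in st.local_shards()]
```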

Differential Revision: [D35917912](https://our.internmc.facebook.com/intern/diff/D35917912/)

Approved by: https://github.com/wanchaol
diff --git a/test/distributed/_shard/sharded_tensor/ops/test_math_ops.py b/test/distributed/_shard/sharded_tensor/ops/test_math_ops.py
index b6c9930..f3129f0 100644
--- a/test/distributed/_shard/sharded_tensor/ops/test_math_ops.py
+++ b/test/distributed/_shard/sharded_tensor/ops/test_math_ops.py
@@ -129,7 +129,7 @@
 
         st = sharded_tensor.rand(spec, 10, 10)
 
-        with self.assertRaisesRegex(TypeError, 'with ChunkShardingSpec supports'):
+        with self.assertRaisesRegex(RuntimeError, 'not supported'):
             torch.add(st, sharded_rhs)
 
 
diff --git a/torch/distributed/_shard/sharded_tensor/_ops/__init__.py b/torch/distributed/_shard/sharded_tensor/_ops/__init__.py
index 2545ccd..f59c8d3 100644
--- a/torch/distributed/_shard/sharded_tensor/_ops/__init__.py
+++ b/torch/distributed/_shard/sharded_tensor/_ops/__init__.py
@@ -1,10 +1,13 @@
 import torch.distributed._shard.sharded_tensor._ops.chunk
 import torch.distributed._shard.sharded_tensor._ops.elementwise_ops
 import torch.distributed._shard.sharded_tensor._ops.math_ops
-import torch.distributed._shard.sharded_tensor._ops.matrix_ops
 
 from .binary_cmp import equal, allclose
-from .embedding import sharded_embedding
-from .embedding_bag import sharded_embedding_bag
 from .init import kaiming_uniform_, normal_, uniform_, constant_
-from .linear import sharded_linear
+
+# Import all ChunkShardingSpec ops
+from torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.linear import sharded_linear
+from torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.embedding import sharded_embedding
+from torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.embedding_bag import sharded_embedding_bag
+import torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.math_ops
+import torch.distributed._shard.sharding_spec.chunk_sharding_spec_ops.matrix_ops
diff --git a/torch/distributed/_shard/sharded_tensor/_ops/_common.py b/torch/distributed/_shard/sharded_tensor/_ops/_common.py
index 3ef25a2..b5dc61f 100644
--- a/torch/distributed/_shard/sharded_tensor/_ops/_common.py
+++ b/torch/distributed/_shard/sharded_tensor/_ops/_common.py
@@ -1,35 +1,9 @@
-# coding=utf-8
-
 import functools
-from typing import List
-
-import torch
-import torch.distributed as dist
-import torch.distributed._shard.sharding_spec as shard_spec
 from torch.distributed._shard.sharded_tensor import (
     sharded_op_impl,
     Shard,
     ShardedTensor,
 )
-from torch.distributed._shard.sharding_spec._internals import (
-    get_split_size,
-    get_chunked_dim_size,
-)
-from torch.distributed.nn.functional import (
-    all_gather,
-    all_to_all_single,
-)
-
-
-def _chunk_sharding_spec_check(spec, op):
-    """
-    For the given op implementation check if the sharding spec is ChunkShardingSpec.
-    """
-    if not isinstance(spec, shard_spec.ChunkShardingSpec):
-        raise NotImplementedError(
-            f"Only ChunkShardingSpec supported for '{op.__name__}'."
-        )
-
 
 def _sharded_op_common(op, early_stop_func, extra_check):
     """
@@ -84,7 +58,6 @@
 
     return decorator_sharded_func
 
-
 def _register_sharded_op_on_local_shards(
     op, early_stop_func=None, extra_check=None, customized_func=None
 ):
@@ -132,421 +105,3 @@
             process_group=pg,
             init_rrefs=st._init_rrefs,
         )
-
-
-def _register_sharded_op_on_local_tensor(
-    op, early_stop_func=None, extra_check=None, customized_func=None
-):
-    """
-    Handles ``__torch_function__`` dispatch for ops which are performed on
-    the single local tensor of the sharded tensor such as op like
-    ``torch.nn.functional.softmax`` or ``torch.Tensor.view``.
-
-    For more complicated ops, a customized func can be used to generate
-    the new local tensor, sharding spec and sharded tensor size.
-
-    Args:
-        op: The op to be registered and applied to all shards of the st.
-        early_stop_func (Callable, optional): the func for early stop.
-            Default: if ``None``, no early stop.
-        extra_check (Callable, optional): the func for extra condition check.
-            Default: if ``None``, no extra check.
-        customized_func (Callable, optional): the func for customized logic
-            to generate the new local tensor, sharding spec and sharded tensor size.
-            Default: if ``None``, we simply lower to the real op call with
-                the single local tensor of the st.
-
-    Return:
-        func (Callable): registered implementation for sharded op for
-        ``__torch_function__`` dispatch.
-    """
-    @sharded_op_impl(op)
-    @_sharded_op_common(op, early_stop_func, extra_check)
-    def sharded_tensor_op_on_local_shards(types, args=(), kwargs=None, pg=None):
-        st = args[0]
-        sharding_spec = st.sharding_spec()
-        _chunk_sharding_spec_check(sharding_spec, op)
-        if len(st.local_shards()) != 1:
-            raise TypeError(
-                f"torch function '{op.__name__}', with args: {args} and "
-                f"kwargs: {kwargs} only supported for single local tensor!"
-            )
-        st_size = st.size()
-        if customized_func:
-            local_tensor, sharding_spec, st_size = customized_func(args, kwargs, pg)
-        else:
-            args = (st.local_tensor(), *args[1:])
-            local_tensor = op(*args, **kwargs)
-        return ShardedTensor._init_from_local_tensor(
-            local_tensor.contiguous(),
-            sharding_spec,
-            st_size,  # type: ignore[arg-type]
-            process_group=pg,
-            init_rrefs=st._init_rrefs,
-        )
-
-
-def _handle_col_wise_sharding_base(
-    op_func,
-    col_dim,
-    input,
-    world_size,
-    weight,
-    local_shard,
-    pg,
-    gathered_inputs=None,
-    mode=None,
-    gathered_per_sample_weights=None,
-    gathered_offsets=None,
-    padding_idx=None,
-):
-    """
-    For col-wise sharding of weight, lots of logic are common.
-    So we extract the common logic and put in this function:
-    Step 1. To get input from each rank and
-    Step 2. To perform the op on the concatenated tensor.
-    Step 3. To distribute results to each rank with col rearrangement.
-    Step 4. To concatenate all results from all ranks.
-
-    Args:
-        op_func: operator which is applied to the input tensor.
-        col_dim: dim of result tensor after the operation.
-        input: tensor to be applied op on.
-        world_size: number of ranks.
-        weight: shareded weight tensor.
-        local_shard: col-wise sharded weight tensor.
-        pg: process group.
-        gathered_inputs: list of inputs from all ranks. If specified, we
-            don't need to communicate with each rank any more.
-        mode: aggregation mode of EmbeddingBag.
-        gathered_per_sample_weights: per_sample_weights across all ranks.
-        gathered_offsets: offsets across all ranks.
-        padding_idx: If specified, the entries at padding_idx do
-            not contribute to the gradient; therefore, the embedding
-            vector at padding_idx is not updated during training,
-            i.e. it remains as a fixed “pad”.
-            Note that the embedding vector at padding_idx is
-            excluded from the reduction.
-
-    Return: final result of input being applied with the op.
-    """
-    if gathered_inputs is None:
-        # allgather the inputs first.
-        gathered_inputs = all_gather(input, group=pg)
-
-    # run the operator's function for all the inputs.
-    results = []
-    for i, inp in enumerate(gathered_inputs):
-        if op_func == torch.nn.functional.embedding_bag:
-            result = op_func(
-                inp,
-                local_shard,
-                offsets=gathered_offsets[i] if gathered_offsets is not None else None,
-                mode=mode,
-                per_sample_weights=gathered_per_sample_weights[i]
-                if gathered_per_sample_weights is not None
-                else None,
-                padding_idx=padding_idx,
-            )
-        elif op_func == torch.nn.functional.embedding:
-            result = op_func(
-                inp,
-                local_shard,
-                padding_idx=padding_idx,
-            )
-        else:
-            result = op_func(inp, local_shard)
-        results.append(torch.transpose(result, 0, col_dim))
-
-    # Distribute results to each rank with col rearrangement.
-    output = _result_distribute_with_col_rearrange(
-        results, input, world_size, weight, pg
-    )
-
-    # transpose the output and return result.
-    return torch.transpose(output, 0, col_dim)
-
-
-def _result_distribute_with_col_rearrange(
-    results, input, world_size, weight, pg
-):
-    """
-    For col-wise sharding of weight, we need to distribute
-    results to each rank. We do them in this function.
-    Note that, if the index in the Sharding Spec is not equal to
-    the rank number, we need to do the rearrangement based on the
-    order given by the Sharding Spec (placement).
-
-    Args:
-        results: results from ops applied to inputs from all ranks.
-            We need to distribute them back to their original ranks.
-        input: tensor to be applied op to.
-        world_size: number of ranks.
-        weight: shareded weight tensor.
-        pg: process group.
-
-    Return: column rearranged result.
-    """
-    # Process results and outputs for all2all.
-    sharding_dim = weight._sharding_spec.dim
-    sharding_dim_size = weight.size(sharding_dim)
-    dims = list(results[0].size())
-    dims[0] = sharding_dim_size
-    output = torch.empty(*dims, device=input.device)
-    combined_results = torch.cat(results)
-
-    # Compute output splits
-    split_size = get_split_size(sharding_dim_size, world_size)
-    output_split_sizes = [0] * world_size
-    for idx, placement in enumerate(weight._sharding_spec.placements):
-        output_split_sizes[placement.rank()] = get_chunked_dim_size(
-            sharding_dim_size, split_size, idx
-        )
-
-    # distribute the outputs using all2all.
-    output = all_to_all_single(
-        output, combined_results, output_split_sizes=output_split_sizes, group=pg
-    )
-
-    # Check if we need to rearrange columns appropriately for output.
-    rearrange_columns = any(
-        [
-            idx != placement.rank()
-            for idx, placement in enumerate(weight._sharding_spec.placements)
-        ]
-    )
-    if not rearrange_columns:
-        return output
-
-    indices = []
-    for placement in weight._sharding_spec.placements:
-        dim_size = output_split_sizes[placement.rank()]
-        start = sum(
-            [
-                split_size if i < placement.rank() else 0
-                for i, split_size in enumerate(output_split_sizes)
-            ]
-        )
-        indices += list(range(start, start + dim_size))
-
-    return output.index_select(0, torch.tensor(indices, device=output.device))
-
-
-def _handle_row_wise_lookup_distribute(
-    input_sorted, input, world_size, weight, rank, padding_idx
-):
-    """
-    In the circumstance of row-wise sharding of weight, we need to distribute
-    the sorted lookup IDs of embedding/embeddingBag to each rank.
-    If the index in the placement is not equal to the rank number, we need to
-    do the rearrangement based on the order given by the Sharding Spec (placement).
-
-    In addition, we do two things for padding_idx. The first thing is to only
-    set it if it's within the range of the current rank and the other thing
-    is to do the modularization of it by sharded_dim_size_max.
-
-    Args:
-        input_sorted: sorted lookup IDs of embedding/embeddingBag.
-        input: tensor to be applied op on.
-        world_size: number of ranks.
-        weight: shareded weight tensor.
-        rank: # of cuda process.
-        padding_idx: If specified, the entries at padding_idx do
-            not contribute to the gradient and reduction.
-
-    Return:
-        input_sorted: sorted lookup IDs of embedding/embeddingBag
-            Rearrangement performed if it is needed.
-        input_split_sizes: size of IDs to be assigned to each rank.
-        sharded_dim_size_max: the max size of the row each rank gets.
-        input_split_rearrange_indices: indices of row rearrangement.
-        rearrange_indices_1d_second_order: reverse indices of row
-            rearrangement, which will be used to restore the original
-            order.
-        padding_idx: Same as input if padding_idx is within the range
-            of the given rank; otherwise, None is returned. It is
-            also modularized by sharded_dim_size_max.
-    """
-    # Decide which rank the input goes to by check the sharding range.
-    split_size = get_split_size(weight.size(0), world_size)
-    rearrange_rows = False
-    indices_flatten = None
-    input_split_sizes: List[int] = [0] * world_size
-    input_split_start_indices: List[int] = [0] * world_size
-    start_row_idx_rank = None
-    end_row_idx_rank = None
-    # When we do the chunk split, we always ensure the first N - 1 chunks get max out
-    # and then the Nth chunk gets the rest. So input_split_sizes like [3, 3, 3, 4]
-    # are not possible. The expected split size will be [4, 4, 4, 1].
-    sharded_dim_size_max = get_chunked_dim_size(weight.size(0), split_size, 0)
-    for idx, placement in enumerate(weight._sharding_spec.placements):
-        sharded_dim_size = get_chunked_dim_size(weight.size(0), split_size, idx)
-        start_row_idx = idx * sharded_dim_size_max
-        end_row_idx = start_row_idx + sharded_dim_size
-        start_idx = torch.searchsorted(input_sorted, start_row_idx).item()
-        end_idx = torch.searchsorted(input_sorted, end_row_idx).item()
-        input_split_sizes[placement.rank()] = int(end_idx - start_idx)
-        input_split_start_indices[placement.rank()] = int(start_idx)
-        if placement.rank() != idx:
-            rearrange_rows = True
-        # Store the range of the current rank.
-        if placement.rank() == rank:
-            start_row_idx_rank = start_row_idx
-            end_row_idx_rank = end_row_idx
-
-    # Perform the modular if padding_idx is within the range.
-    if padding_idx is not None:
-        if padding_idx < start_row_idx_rank or padding_idx >= end_row_idx_rank:
-            padding_idx = None
-        else:
-            padding_idx = padding_idx % sharded_dim_size_max
-
-    rearrange_indices_1d_second_order = None
-    if rearrange_rows:
-        # Need to re-arrange the 1D tensor to be sent via all2all.
-        indices: List[List[int]] = [[0]] * world_size
-        for placement in weight._sharding_spec.placements:
-            split_length = input_split_sizes[placement.rank()]
-            offset_idx = input_split_start_indices[placement.rank()]
-            indices[placement.rank()] = list(
-                range(offset_idx, offset_idx + split_length)
-            )
-        indices_flatten = list(idx for indice in indices for idx in indice)
-
-        input_sorted = input_sorted.index_select(
-            0, torch.tensor(indices_flatten, device=input.device)
-        )
-        rearrange_indices_1d_second_order = torch.argsort(torch.Tensor(indices_flatten))
-
-    return (
-        input_sorted,
-        input_split_sizes,
-        sharded_dim_size_max,
-        torch.tensor(indices_flatten, device=input.device) if rearrange_rows else None,
-        rearrange_indices_1d_second_order,
-        padding_idx,
-    )
-
-
-def _communicate_size_to_each_rank(
-    input_size_list, output_size, input, pg, tensor_type=torch.int
-):
-    """
-    In the circumstance of row-wise sharding of weight, we need to first
-    communicate the input length to each rank because each rank gets a
-    different one.
-
-    Args:
-        input_size_list: list of sizes to be sent to each rank.
-        output_size: length of the output tensor.
-        input: tensor to be applied op on.
-        pg: process group.
-        tensor_type: dtype of tensor.
-
-    Return: A list of communication results (int).
-    """
-    input_size_list_tensor = torch.tensor(
-        input_size_list, dtype=tensor_type, device=input.device
-    )
-    output_size_list_tensor = torch.empty(
-        output_size, dtype=tensor_type, device=input.device
-    )
-    dist.all_to_all_single(
-        output_size_list_tensor,
-        input_size_list_tensor,
-        group=pg,
-    )
-    return output_size_list_tensor.tolist()
-
-
-def _communicate_list_to_each_rank(
-    input_tensor_list, output_lists, input, pg, tensor_type=torch.int64
-):
-    """
-    In the circumstance of row-wise sharding of weight, we need to
-    communicate a list of input tensors to each rank. Because the
-    input could be a list of list, we need to first convert the list
-    to a tensor.
-
-    Args:
-        input_tensor_list: list of tensors to be sent to each rank.
-        output_lists: list of sizes to be obtained from each rank.
-        input: tensor to be applied op on.
-        pg: process group.
-        tensor_type: dtype of tensor.
-
-    Return: A list of communication results (tensors).
-    """
-    output_tensor_list = []
-    for output_list in output_lists:
-        output_tensor_list.append(
-            torch.empty(output_list, dtype=tensor_type, device=input.device)
-        )
-    dist.all_to_all(
-        output_tensor_list,
-        input_tensor_list,
-        group=pg,
-    )
-    return output_tensor_list
-
-
-def _handle_max_norm_col_wise(
-    max_norm,
-    norm_type,
-    local_shard,
-    input,
-    world_size,
-    pg,
-):
-    """
-    For col-wise sharding of weight, we need to aggregate the
-    norm across all ranks before we can perform the proper re-norm.
-    Note that, the max_norm logic is only applied to the embedding
-    indices that are looked up and not the whole shard.
-
-    Args:
-        max_norm: If given, each embedding vector with norm larger
-            than max_norm is renormalized to have norm max_norm.
-            Note: this will modify weight in-place.
-        norm_type: The p in the p-norm to compute for the max_norm option.
-        local_shard: col-wise shared local weight used for lookup.
-        input: tensor to be applied op to.
-        world_size: number of ranks.
-        pg: process group.
-
-    Return:
-        local_shard_norm_renormed: local_shard re-normed to max_norm if the norm is larger
-            than it.
-        gathered_inputs: list of inputs from all ranks.
-    """
-    norm_type = norm_type if norm_type is not None else 2.0
-    # allgather the inputs first.
-    gathered_inputs = [torch.zeros_like(input) for _ in range(world_size)]
-    dist.all_gather(gathered_inputs, input, group=pg)
-    unique_inp = torch.unique(torch.cat(gathered_inputs))
-    local_shard_sum = torch.sum(
-        torch.pow(torch.abs(local_shard), norm_type), dim=1, dtype=local_shard.dtype
-    )
-    # For col-wise sharding, we need to first aggregate the powered sum
-    # from each rank first and then calculate the norm.
-    dist.all_reduce(local_shard_sum, group=pg)
-    local_shard_norm = torch.pow(local_shard_sum, 1.0 / norm_type)
-    max_norm_tensor = torch.full(
-        (local_shard.size(0),),
-        float("inf"),
-        dtype=local_shard.dtype,
-        device=input.device,
-    )
-    max_norm_tensor[unique_inp] = max_norm
-    local_shard_t = local_shard.t().contiguous()
-    normalized_tensor = torch.where(
-        local_shard_norm > max_norm_tensor, max_norm_tensor, local_shard_norm
-    )
-    # Make sure divisor is not zero.
-    local_shard_norm[local_shard_norm == 0.0] = 1.0
-    local_shard_norm_renormed = (
-        torch.div(torch.mul(local_shard_t, normalized_tensor), local_shard_norm)
-        .t()
-        .contiguous()
-    )
-    return local_shard_norm_renormed, gathered_inputs
diff --git a/torch/distributed/_shard/sharded_tensor/_ops/default_tensor_ops.py b/torch/distributed/_shard/sharded_tensor/_ops/default_tensor_ops.py
index 2b8c2a4..8c23f1d 100644
--- a/torch/distributed/_shard/sharded_tensor/_ops/default_tensor_ops.py
+++ b/torch/distributed/_shard/sharded_tensor/_ops/default_tensor_ops.py
@@ -1,6 +1,12 @@
+import copy
 import torch
 from torch.distributed._shard.sharded_tensor import (
     sharded_op_impl,
+    Shard,
+    ShardedTensor,
+)
+from ._common import (
+    _register_sharded_op_on_local_shards,
 )
 
 
@@ -33,3 +39,58 @@
 
 # __reduce_ex__ to dispatch to get_state/set_state
 register_default_op(torch.Tensor.__reduce_ex__)
+
+def sharded_type_as_check(*args, **kwargs):
+    """
+    Perform extra checks for the sharded_type_as op such as the input needs to
+    be either a Tensor or ShardedTensor.
+
+    Args: same as ``torch.Tensor.type_as``.
+
+    Return: None
+    """
+    if len(args) < 2:
+        raise ValueError("Needs to give a tensor to cast type as!")
+    if not isinstance(args[1], torch.Tensor) and not isinstance(args[1], ShardedTensor):
+        raise ValueError("Needs to give a Tensor or ShardedTensor to cast type as!")
+
+
+def same_dtype(*args, **kwargs):
+    """
+    When the dtype is the same, return the original ShardedTensor.
+
+    Args: same as ``torch.Tensor.type_as``.
+
+    Return (bool): Whether to return early or not.
+    """
+    return args[0].dtype == args[1].dtype
+
+
+def sharded_type_as(args, kwargs, pg):
+    """
+    Handles ``__torch_function__`` dispatch for the ``torch.Tensor.type_as`` op.
+
+    Args: same as ``torch.Tensor.type_as``.
+
+    Return:
+        new_local_shards (List[Shard]): Local shards for the new sharded tensor.
+        st_meta (ShardedTensorMetadata): Metadata of the new sharded tensor.
+    """
+    st = args[0]
+    tensor = args[1]
+    if isinstance(tensor, ShardedTensor):
+        tensor = tensor.local_tensor()
+    new_local_shards = []
+    for shard in st.local_shards():
+        new_local_shards.append(Shard(shard.tensor.type_as(tensor), shard.metadata))
+    st_meta = copy.deepcopy(st._metadata)
+    st_meta.tensor_properties.dtype = tensor.dtype
+    return new_local_shards, st_meta
+
+
+_register_sharded_op_on_local_shards(
+    torch.Tensor.type_as,
+    early_stop_func=same_dtype,
+    extra_check=sharded_type_as_check,
+    customized_func=sharded_type_as,
+)
diff --git a/torch/distributed/_shard/sharded_tensor/_ops/elementwise_ops.py b/torch/distributed/_shard/sharded_tensor/_ops/elementwise_ops.py
index a53eebc..dc65277 100644
--- a/torch/distributed/_shard/sharded_tensor/_ops/elementwise_ops.py
+++ b/torch/distributed/_shard/sharded_tensor/_ops/elementwise_ops.py
@@ -4,7 +4,6 @@
     _register_sharded_op_on_local_shards,
 )
 
-
 _register_sharded_op_on_local_shards(torch.nn.functional.gelu)
 _register_sharded_op_on_local_shards(torch.nn.functional.relu)
 _register_sharded_op_on_local_shards(torch.nn.functional.dropout)
diff --git a/torch/distributed/_shard/sharded_tensor/_ops/math_ops.py b/torch/distributed/_shard/sharded_tensor/_ops/math_ops.py
index 205f016..6d3ed59 100644
--- a/torch/distributed/_shard/sharded_tensor/_ops/math_ops.py
+++ b/torch/distributed/_shard/sharded_tensor/_ops/math_ops.py
@@ -4,102 +4,80 @@
     ShardedTensor,
     sharded_op_impl
 )
-from torch.distributed._shard.sharding_spec import ChunkShardingSpec
 from torch.distributed._shard.replicated_tensor import ReplicatedTensor
 
 from torch.distributed._shard._utils import narrow_tensor
 
-from ._common import (
-    _chunk_sharding_spec_check,
-    _register_sharded_op_on_local_tensor,
-)
+def binary_math_op_impl(op, types, args=(), kwargs=None, pg=None):
+    """
+    Handles ``__torch_function__`` dispatch for the binary math ops
+    such as `torch.add`, `torch.mul`, `torch.div`, etc.
+    This method computes on ShardedTensor, or ShardedTensor op ReplicatedTensor
+    """
+    if len(args) != 2:
+        raise ValueError("Only support binary math op on ShardedTensor for now!")
+    lhs = args[0]
+    rhs = args[1]
+    # Validate types
+    if isinstance(lhs, ReplicatedTensor):
+        assert isinstance(rhs, ShardedTensor)
+        st_size = rhs.size()
+        st_meta = rhs.local_shards()[0].metadata
+        if st_size != lhs.size():
+            # try to broadcast replicated tensor
+            lhs = lhs.expand(st_size)
 
+        replica_part = narrow_tensor(lhs, st_meta)
+        res = op(replica_part, rhs.local_tensor())
+
+        return ShardedTensor._init_from_local_tensor(
+            res,
+            rhs.sharding_spec(),
+            rhs.size(),  # type: ignore[arg-type]
+            process_group=pg)
+
+    elif isinstance(rhs, ReplicatedTensor):
+        assert isinstance(lhs, ShardedTensor)
+        st_size = lhs.size()
+        st_meta = lhs.local_shards()[0].metadata
+        if st_size != rhs.size():
+            # try to broadcast replicated tensor
+            rhs = rhs.expand(st_size)
+
+        replica_part = narrow_tensor(rhs, st_meta)
+        res = op(lhs.local_tensor(), replica_part)
+        return ShardedTensor._init_from_local_tensor(
+            res,
+            lhs.sharding_spec(),
+            lhs.size(),  # type: ignore[arg-type]
+            process_group=pg)
+
+    elif isinstance(lhs, (int, float)):
+        assert isinstance(rhs, ShardedTensor)
+        res = op(lhs, rhs.local_tensor())
+        return ShardedTensor._init_from_local_tensor(
+            res,
+            rhs.sharding_spec(),
+            rhs.size(),  # type: ignore[arg-type]
+            process_group=pg)
+
+    elif isinstance(rhs, (int, float)):
+        assert isinstance(lhs, ShardedTensor)
+        res = op(lhs.local_tensor(), rhs)
+        return ShardedTensor._init_from_local_tensor(
+            res,
+            lhs.sharding_spec(),
+            lhs.size(),  # type: ignore[arg-type]
+            process_group=pg)
+    else:
+        raise RuntimeError(
+            f"torch function '{op.__name__}', with args: {args} and "
+            f"kwargs: {kwargs} not supported yet for ShardedTensor!")
 
 def register_math_op(op):
     @sharded_op_impl(op)
     def binary_math_op(types, args=(), kwargs=None, pg=None):
-        """
-        Handles ``__torch_function__`` dispatch for the binary math ops
-        such as `torch.add`, `torch.mul`, `torch.div`, etc.
-        This method computes on ShardedTensor, or ShardedTensor op ReplicatedTensor
-        """
-        if len(args) != 2:
-            raise ValueError("Only support binary math op on ShardedTensor for now!")
-        lhs = args[0]
-        rhs = args[1]
-        # Validate types
-        if isinstance(lhs, ShardedTensor) and isinstance(rhs, ShardedTensor):
-            lhs_spec = lhs.sharding_spec()
-            rhs_spec = rhs.sharding_spec()
-            if not isinstance(lhs_spec, ChunkShardingSpec) or not isinstance(rhs_spec, ChunkShardingSpec):
-                raise TypeError("Only ShardedTensor with ChunkShardingSpec supports"
-                                " two ShardedTensor together")
-
-            if lhs.size() == rhs.size() and lhs_spec.dim == rhs_spec.dim:
-                # perform local element-wise math op
-                res = op(lhs.local_tensor(), rhs.local_tensor())
-                return ShardedTensor._init_from_local_tensor(
-                    res,
-                    lhs_spec,
-                    lhs.size(),  # type: ignore[arg-type]
-                    process_group=pg)
-            else:
-                raise RuntimeError("Implicit broadcasting not supported yet!")
-
-        elif isinstance(lhs, ReplicatedTensor):
-            assert isinstance(rhs, ShardedTensor)
-            st_size = rhs.size()
-            st_meta = rhs.local_shards()[0].metadata
-            if st_size != lhs.size():
-                # try to broadcast replicated tensor
-                lhs = lhs.expand(st_size)
-
-            replica_part = narrow_tensor(lhs, st_meta)
-            res = op(replica_part, rhs.local_tensor())
-
-            return ShardedTensor._init_from_local_tensor(
-                res,
-                rhs.sharding_spec(),
-                rhs.size(),  # type: ignore[arg-type]
-                process_group=pg)
-
-        elif isinstance(rhs, ReplicatedTensor):
-            assert isinstance(lhs, ShardedTensor)
-            st_size = lhs.size()
-            st_meta = lhs.local_shards()[0].metadata
-            if st_size != rhs.size():
-                # try to broadcast replicated tensor
-                rhs = rhs.expand(st_size)
-
-            replica_part = narrow_tensor(rhs, st_meta)
-            res = op(lhs.local_tensor(), replica_part)
-            return ShardedTensor._init_from_local_tensor(
-                res,
-                lhs.sharding_spec(),
-                lhs.size(),  # type: ignore[arg-type]
-                process_group=pg)
-
-        elif isinstance(lhs, (int, float)):
-            assert isinstance(rhs, ShardedTensor)
-            res = op(lhs, rhs.local_tensor())
-            return ShardedTensor._init_from_local_tensor(
-                res,
-                rhs.sharding_spec(),
-                rhs.size(),  # type: ignore[arg-type]
-                process_group=pg)
-
-        elif isinstance(rhs, (int, float)):
-            assert isinstance(lhs, ShardedTensor)
-            res = op(lhs.local_tensor(), rhs)
-            return ShardedTensor._init_from_local_tensor(
-                res,
-                lhs.sharding_spec(),
-                lhs.size(),  # type: ignore[arg-type]
-                process_group=pg)
-        else:
-            raise RuntimeError(
-                f"torch function '{op.__name__}', with args: {args} and "
-                f"kwargs: {kwargs} not supported yet for ShardedTensor!")
+        return binary_math_op_impl(op, types, args, kwargs, pg)
 
 binary_ops = [
     # add
@@ -126,70 +104,3 @@
 
 for op in binary_ops:
     register_math_op(op)
-
-
-def sharded_bmm_check(*args, **kwargs):
-    """
-    Perform extra checks for the sharded_bmm op, for example, st2 needs to
-    be a sharded tensor and both tensors need to sharded by dim 0, etc.
-
-    Args: same as ``torch.bmm``.
-
-    Return: None
-    """
-    if len(args) < 2:
-        raise TypeError("Needs two tensors to perform torch.bmm.")
-    st = args[0]
-    st2 = args[1]
-    # Validate types
-    if not isinstance(st2, ShardedTensor):
-        raise TypeError("st2 needs to be a ShardedTensor for torch.bmm.")
-    _chunk_sharding_spec_check(st2.sharding_spec(), torch.bmm)
-    if st.dim() != 3 or st2.dim() != 3:
-        raise TypeError("both st and st2 need to be a 3D ShardedTensor")
-    if (
-        st.sharding_spec().dim != st2.sharding_spec().dim  # type: ignore[attr-defined]
-        or st.sharding_spec().dim != 0
-    ):
-        raise NotImplementedError(
-            "Only support performing bmm on tensors sharded on dim 0 now."
-        )
-    if st.sharding_spec().placements != st2.sharding_spec().placements:  # type: ignore[attr-defined]
-        raise NotImplementedError(
-            "Both st and st2 need to have same placements for bmm."
-        )
-
-
-def sharded_bmm(args, kwargs, pg):
-    """
-    Handles ``__torch_function__`` dispatch for the sharded_bmm op.
-
-    Warning: For now we only supports the case when both tensors are sharded
-             by dim 0 so that no communication is needed.
-
-    Args: same as ``torch.bmm``.
-
-    Return:
-        local_tensor (Tensor): New local tensor to build the sharded tensor.
-        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`):
-            sharding spec of the new sharded tensor.
-        new_st_size (torch.Size): Size of the new sharded tensor.
-    """
-    st = args[0]
-    st2 = args[1]
-    local_tensor = torch.bmm(st.local_tensor(), st2.local_tensor())
-    new_st_size = (*st.size()[:-1], st2.size(-1))
-    return local_tensor, st.sharding_spec(), new_st_size
-
-
-_register_sharded_op_on_local_tensor(
-    torch.Tensor.bmm,
-    extra_check=sharded_bmm_check,
-    customized_func=sharded_bmm,
-)
-
-_register_sharded_op_on_local_tensor(
-    torch.bmm,
-    extra_check=sharded_bmm_check,
-    customized_func=sharded_bmm,
-)
diff --git a/torch/distributed/_shard/sharded_tensor/api.py b/torch/distributed/_shard/sharded_tensor/api.py
index e66526a..283dfb7 100644
--- a/torch/distributed/_shard/sharded_tensor/api.py
+++ b/torch/distributed/_shard/sharded_tensor/api.py
@@ -17,6 +17,10 @@
 from torch.distributed import rpc
 from torch.distributed import distributed_c10d
 import torch.distributed._shard.sharding_spec as shard_spec
+from torch.distributed._shard.sharding_spec.api import (
+    _dispatch_custom_op,
+    _has_custom_op,
+)
 from torch.distributed._shard.sharding_spec._internals import (
     check_tensor,
     validate_non_overlapping_shards_metadata,
@@ -39,7 +43,7 @@
 _sharded_tensor_map: Dict[int, 'weakref.ReferenceType[ShardedTensor]'] = {}
 
 # Custom sharded ops
-_SHARDED_OPS: Dict[str, Callable] = {}
+_SHARDED_OPS: Dict[Callable, Callable] = {}
 def _register_sharded_op(op, func):
     from inspect import signature
     if len(signature(func).parameters) != 4:
@@ -524,6 +528,7 @@
             sharded_tensor_metadata,
             process_group=process_group,
             init_rrefs=init_rrefs,
+            sharding_spec=sharding_spec,
         )
 
     @classmethod
@@ -533,6 +538,7 @@
         sharded_tensor_metadata: ShardedTensorMetadata,
         process_group=None,
         init_rrefs=False,
+        sharding_spec=None,
     ) -> "ShardedTensor":
         """
         Initialize a ShardedTensor with local shards and a global
@@ -615,7 +621,10 @@
 
         # done validation, add local_shards
         sharded_tensor._local_shards = local_shards
-        sharded_tensor._sharding_spec = shard_spec._infer_sharding_spec_from_shards_metadata(shards_metadata)
+        if sharding_spec is None:
+            sharded_tensor._sharding_spec = shard_spec._infer_sharding_spec_from_shards_metadata(shards_metadata)
+        else:
+            sharded_tensor._sharding_spec = sharding_spec
 
         # run post initialization, i.e. map registration, rpc initialization
         sharded_tensor._post_init()
@@ -741,15 +750,26 @@
 
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
-        if func in _SHARDED_OPS:
-            # Find ShardedTensor instance to get process_group.
-            for arg in args:
-                if isinstance(arg, ShardedTensor):
-                    return _SHARDED_OPS[func](types, args, kwargs, arg._process_group)
+        def dispatch(st: ShardedTensor, func: Callable):
+            # Dispatch to custom sharding spec op if it has one.
+            if _has_custom_op(st._sharding_spec, func):
+                return _dispatch_custom_op(st._sharding_spec, func, types, args, kwargs)
 
-            for kwarg in kwargs.values():
-                if isinstance(kwarg, ShardedTensor):
-                    return _SHARDED_OPS[func](types, args, kwargs, kwarg._process_group)
+            if func in _SHARDED_OPS:
+                return _SHARDED_OPS[func](types, args, kwargs, st._process_group)
+
+            raise RuntimeError(
+                f"torch function '{func.__name__}', with args: {args} and "
+                f"kwargs: {kwargs} not supported for ShardedTensor!")
+
+        # Find ShardedTensor instance to get process_group and sharding_spec.
+        for arg in args:
+            if isinstance(arg, ShardedTensor):
+                return dispatch(arg, func)
+
+        for kwarg in kwargs.values():
+            if isinstance(kwarg, ShardedTensor):
+                return dispatch(kwarg, func)
 
         raise RuntimeError(
             f"torch function '{func.__name__}', with args: {args} and "
@@ -994,6 +1014,12 @@
     def __rtruediv__(self, other):
         return handle_torch_function(torch.Tensor.__rdiv__, (self, other), self, other)
 
+    def tanh(self):
+        return handle_torch_function(torch.Tensor.tanh, (self,), self)
+
+    def __getitem__(self, key):
+        return handle_torch_function(torch.Tensor.__getitem__, (self, key), self, key)
+
     @dataclass
     class ProcessGroupState:
         """
diff --git a/torch/distributed/_shard/sharding_spec/__init__.py b/torch/distributed/_shard/sharding_spec/__init__.py
index 6f2e5c7..e356295 100644
--- a/torch/distributed/_shard/sharding_spec/__init__.py
+++ b/torch/distributed/_shard/sharding_spec/__init__.py
@@ -1,10 +1,12 @@
 from .api import (
-    ChunkShardingSpec,
     DevicePlacementSpec,
     EnumerableShardingSpec,
     PlacementSpec,
     ShardingSpec,
     _infer_sharding_spec_from_shards_metadata,
 )
+from .chunk_sharding_spec import (
+    ChunkShardingSpec,
+)
 
 from torch.distributed._shard.metadata import ShardMetadata
diff --git a/torch/distributed/_shard/sharding_spec/api.py b/torch/distributed/_shard/sharding_spec/api.py
index a42add7..32c07cd 100644
--- a/torch/distributed/_shard/sharding_spec/api.py
+++ b/torch/distributed/_shard/sharding_spec/api.py
@@ -1,7 +1,7 @@
 from abc import ABC, abstractmethod
 from dataclasses import dataclass
-from typing import List, Union
-from typing import TYPE_CHECKING
+import functools
+from typing import Callable, Dict, List, Type, TYPE_CHECKING
 
 import torch
 
@@ -13,14 +13,7 @@
 )
 from torch.distributed._shard.metadata import ShardMetadata
 
-from torch.distributed._shard.sharded_tensor.utils import (
-    _parse_and_validate_remote_device
-)
-
-import torch.distributed as dist
 import torch.distributed._shard.sharded_tensor.metadata as sharded_tensor_meta
-from torch.distributed._shard.sharded_tensor.shard import Shard
-from torch.distributed._shard._utils import narrow_tensor
 
 if TYPE_CHECKING:
     # Only include ShardedTensor when do type checking, exclude it
@@ -51,7 +44,7 @@
         if not isinstance(self.device, torch.distributed._remote_device):
             self.device = torch.distributed._remote_device(self.device)
 
-class ShardingSpec(object):
+class ShardingSpec(ABC):
     """
     Base class representing sharding specifications.
     """
@@ -92,174 +85,63 @@
             A :class:`ShardedTensor` sharded from the given tensor.
         """
 
-@dataclass
-class ChunkShardingSpec(ShardingSpec):
+# Ops customized for a particular ShardingSpec.
+CUSTOM_SHARDING_SPEC_OPS: Dict[str, Dict[Callable, Callable]] = {}
+
+def _register_custom_op(sharding_spec_cls: Type, op: Callable, func: Callable):
     """
-    This is a type of PlacementSpec that defines the placement as being sharded
-    across multiple devices. In particular, it represents sharding a Tensor
-    along a single dimension into equal chunks (similar to :meth:`torch.chunk`).
-
-    The semantics of how a tensor is partitioned is inline with
-    :meth:`torch.chunk`, where ``dim`` in torch.chunk corresponds to the
-    specified ``dim`` and ``chunks`` in torch.chunk is the number of elements
-    in the placement specified.
-
+    Allows registration of a custom op for ShardedTensor to enable
+    custom optimizations for a particular ShardingSpec.
     Args:
-        dim (int or str):
-            The dimension to shard on, could be an integer representing the
-            dimension or a string in case of named tensors where dimensions are
-            named. Note that named tensor support is not added yet.
-        placement(List[Union[_remote_device, str]]):
-            Specifies the placement of each shard of the Tensor. The size of
-            the list represents the number of shards to be created. This could
-            be a list of
-            :class:`torch.distributed._remote_device`'s. This list
-            could also contain a string which represents remote
-            device as accepted by
-            :class:`torch.distributed._remote_device`
+        sharding_spec(type): The ShardingSpec for which we need to add this custom op.
+        op(Callable): The op to override (ex: torch.bmm)
+        func(Callable): The custom implementation for ``op``
     """
+    from inspect import signature
+    if len(signature(func).parameters) != 3:
+        raise TypeError(
+            f'Custom sharded op function expects signature: '
+            f'(types, args, kwargs), but received '
+            f'signature: {signature(func)}')
 
-    ShardingDim = Union[int, str]
+    global CUSTOM_SHARDING_SPEC_OPS
+    class_name = sharding_spec_cls.__qualname__
+    if class_name not in CUSTOM_SHARDING_SPEC_OPS:
+        CUSTOM_SHARDING_SPEC_OPS[class_name] = {}
+    CUSTOM_SHARDING_SPEC_OPS[class_name][op] = func
 
-    dim: ShardingDim
-    placements: List[Union[torch.distributed._remote_device, str]]
+def _has_custom_op(sharding_spec, op):
+    """
+    Returns whether or not the ShardingSpec has a custom op implementation.
+    """
+    class_name = type(sharding_spec).__qualname__
+    return class_name in CUSTOM_SHARDING_SPEC_OPS and op in CUSTOM_SHARDING_SPEC_OPS[class_name]
 
-    def __post_init__(self):
-        self._verify_dim(self.dim)
-        for i, remote_device in enumerate(self.placements):
-            if not isinstance(remote_device, torch.distributed._remote_device):
-                self.placements[i] = torch.distributed._remote_device(remote_device)
+def _dispatch_custom_op(sharding_spec, op: Callable, types, args, kwargs):
+    """
+    Calls the custom op for this ShardingSpec if it exists.
+    """
+    class_name = type(sharding_spec).__qualname__
+    if not _has_custom_op(sharding_spec, op):
+        raise RuntimeError(f'Custom op: {op} not registered for {class_name}')
+    func = CUSTOM_SHARDING_SPEC_OPS[class_name][op]
+    return func(types, args, kwargs)
 
-    @staticmethod
-    def _verify_dim(dim):
-        # Validate the sharding spec.
-        # TODO: support named dimension
-        if isinstance(dim, str):
-            raise NotImplementedError(
-                "ChunkShardingSpec does not support named dimension yet!"
-            )
+def custom_sharding_spec_op(sharding_spec_class, func):
+    """
+    Decorator to allow custom registration of ops.
+    Args:
+        sharding_spec_class(type): The ShardingSpec for which we need to add this custom op.
+        func(Callable): The op to override (ex: torch.bmm)
+    """
+    def decorator_sharded_func(wrapped_func):
+        _register_custom_op(sharding_spec_class, func, wrapped_func)
 
-        if not isinstance(dim, int):
-            raise ValueError(
-                f"Sharding dim needs to be an integer, found: {dim}"
-            )
-
-    def build_metadata(self,
-                       tensor_sizes: torch.Size,
-                       tensor_properties: sharded_tensor_meta.TensorProperties,
-                       ) -> sharded_tensor_meta.ShardedTensorMetadata:
-        tensor_num_dim = len(tensor_sizes)
-
-        self._verify_dim(self.dim)
-        if self.dim >= tensor_num_dim or self.dim < -tensor_num_dim:  # type: ignore[operator]
-            raise ValueError(f"Invalid sharding dim: {self.dim}")
-
-        shards_metadata = []
-        sharding_dim_size = tensor_sizes[self.dim]  # type: ignore[index]
-        chunks = len(self.placements)
-        split_size = get_split_size(sharding_dim_size, chunks)
-        for idx, placement in enumerate(self.placements):
-            # generate ShardMetadata for each placement device
-            chunked_dim_size = get_chunked_dim_size(sharding_dim_size, split_size, idx)
-            if chunked_dim_size > 0:
-                shard_size = list(tensor_sizes)
-                current_offsets = [0] * tensor_num_dim
-                current_offsets[self.dim] = split_size * idx  # type: ignore[index]
-                shard_size[self.dim] = chunked_dim_size  # type: ignore[index]
-
-                shard_metadata = ShardMetadata(
-                    shard_offsets=current_offsets,
-                    shard_sizes=shard_size,
-                    placement=placement,
-                )
-                shards_metadata.append(shard_metadata)
-
-        return sharded_tensor_meta.ShardedTensorMetadata(
-            shards_metadata,
-            tensor_sizes,
-            tensor_properties
-        )
-
-
-    def shard(self, tensor: torch.Tensor, src_rank: int = 0, process_group=None) -> "ShardedTensor":
-        # relative imports to avoid circular dependency
-        from torch.distributed._shard.sharded_tensor import (
-            ShardedTensor
-        )
-        tensor_properties = sharded_tensor_meta.TensorProperties(
-            dtype=tensor.dtype,
-            layout=tensor.layout,
-            requires_grad=tensor.requires_grad,
-            memory_format=torch.contiguous_format,
-            pin_memory=tensor.is_pinned()
-        )
-        current_rank = dist.get_rank(process_group)
-        tensor_meta = self.build_metadata(tensor.size(), tensor_properties)
-        local_shards = []
-        local_tensor = None
-        local_metadata = None
-        tensors_to_scatter = [None] * dist.get_world_size(process_group)
-
-        sharding_dim_size = tensor.size()[self.dim]  # type: ignore[index]
-        chunks = len(self.placements)
-        split_size = get_split_size(sharding_dim_size, chunks)
-        scatter_shape = list(tensor.size())
-        scatter_shape[self.dim] = split_size  # type: ignore[index]
-
-        for shard_meta in tensor_meta.shards_metadata:
-            rank, device = _parse_and_validate_remote_device(process_group, shard_meta.placement)
-            if current_rank == src_rank:
-                # Reshape to get shard for this rank and we don't want autograd
-                # recording here for the narrow op and 'local_shard' should be a
-                # leaf variable in the autograd graph.
-                narrowed_tensor = narrow_tensor(tensor, shard_meta)
-                if shard_meta.shard_sizes[self.dim] < split_size:  # type: ignore[index]
-                    # for the last shard that might be smaller to other shards
-                    # resize the narrowed tensor to the same size and use it for
-                    # the scatter collective as dist.scatter requires same size
-                    # inputs on every rank
-                    tensor_to_scatter = narrowed_tensor.detach().clone().resize_(scatter_shape)
-                else:
-                    tensor_to_scatter = narrowed_tensor.detach().clone().contiguous()
-
-                tensors_to_scatter[rank] = tensor_to_scatter
-
-            if current_rank == rank:
-                local_tensor = torch.empty(
-                    scatter_shape, dtype=tensor.dtype, layout=tensor.layout, device=device)
-                local_metadata = shard_meta
-
-        # each rank should have local_tensor and local_metadata initialized if we build
-        # the metadata list in a correct way.
-        assert local_tensor is not None
-        assert local_metadata is not None
-
-        # Scatter the shards to all ranks in the pg
-        dist.scatter(
-            local_tensor,
-            scatter_list=tensors_to_scatter if current_rank == src_rank else None,
-            src=src_rank,
-            group=process_group
-        )
-
-        if list(local_tensor.size()) != local_metadata.shard_sizes:
-            # detach again after receiving to ensure local shards remain a leaf node
-            local_tensor = local_tensor.resize_(local_metadata.shard_sizes).detach()
-
-        # Sync requires_grad to local_shard.
-        local_tensor.requires_grad = tensor.requires_grad
-
-        local_shards.append(Shard(tensor=local_tensor, metadata=local_metadata))
-
-        st = ShardedTensor._init_from_local_shards_and_global_metadata(
-            local_shards,
-            tensor_meta,
-            process_group=process_group)
-
-        # Manually set sharding_spec
-        st._sharding_spec = self
-
-        return st
+        @functools.wraps(wrapped_func)
+        def wrapper(*args, **kwargs):
+            return wrapped_func(*args, **kwargs)
+        return wrapper
+    return decorator_sharded_func
 
 
 @dataclass
@@ -353,6 +235,8 @@
         placements = [
             x for _, x in sorted(zip(chunk_offset_list, placements), key=lambda e: e[0])
         ]
+
+        from .chunk_sharding_spec import ChunkShardingSpec
         chunk_spec = ChunkShardingSpec(
             dim=chunk_sharding_dim,
             placements=placements,
diff --git a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py
new file mode 100644
index 0000000..479eea2
--- /dev/null
+++ b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec.py
@@ -0,0 +1,193 @@
+from dataclasses import dataclass
+import torch
+import torch.distributed._shard.sharded_tensor.metadata as sharded_tensor_meta
+from torch.distributed._shard.metadata import ShardMetadata
+from torch.distributed._shard.sharded_tensor.shard import Shard
+from torch.distributed._shard.sharded_tensor.utils import (
+    _parse_and_validate_remote_device
+)
+from torch.distributed._shard._utils import narrow_tensor
+import torch.distributed as dist
+from typing import List, Union, TYPE_CHECKING
+from ._internals import (
+    get_chunked_dim_size,
+    get_split_size,
+)
+
+from .api import ShardingSpec
+
+if TYPE_CHECKING:
+    # Only include ShardedTensor when do type checking, exclude it
+    # from run-time to resolve circular dependency.
+    from torch.distributed._shard.sharded_tensor import ShardedTensor
+
+@dataclass
+class ChunkShardingSpec(ShardingSpec):
+    """
+    This is a type of PlacementSpec that defines the placement as being sharded
+    across multiple devices. In particular, it represents sharding a Tensor
+    along a single dimension into equal chunks (similar to :meth:`torch.chunk`).
+
+    The semantics of how a tensor is partitioned is inline with
+    :meth:`torch.chunk`, where ``dim`` in torch.chunk corresponds to the
+    specified ``dim`` and ``chunks`` in torch.chunk is the number of elements
+    in the placement specified.
+
+    Args:
+        dim (int or str):
+            The dimension to shard on, could be an integer representing the
+            dimension or a string in case of named tensors where dimensions are
+            named. Note that named tensor support is not added yet.
+        placement(List[Union[_remote_device, str]]):
+            Specifies the placement of each shard of the Tensor. The size of
+            the list represents the number of shards to be created. This could
+            be a list of
+            :class:`torch.distributed._remote_device`'s. This list
+            could also contain a string which represents remote
+            device as accepted by
+            :class:`torch.distributed._remote_device`
+    """
+
+    ShardingDim = Union[int, str]
+
+    dim: ShardingDim
+    placements: List[Union[torch.distributed._remote_device, str]]
+
+    def __post_init__(self):
+        self._verify_dim(self.dim)
+        for i, remote_device in enumerate(self.placements):
+            if not isinstance(remote_device, torch.distributed._remote_device):
+                self.placements[i] = torch.distributed._remote_device(remote_device)
+
+    @staticmethod
+    def _verify_dim(dim):
+        # Validate the sharding spec.
+        # TODO: support named dimension
+        if isinstance(dim, str):
+            raise NotImplementedError(
+                "ChunkShardingSpec does not support named dimension yet!"
+            )
+
+        if not isinstance(dim, int):
+            raise ValueError(
+                f"Sharding dim needs to be an integer, found: {dim}"
+            )
+
+    def build_metadata(self,
+                       tensor_sizes: torch.Size,
+                       tensor_properties: sharded_tensor_meta.TensorProperties,
+                       ) -> sharded_tensor_meta.ShardedTensorMetadata:
+        tensor_num_dim = len(tensor_sizes)
+
+        self._verify_dim(self.dim)
+        if self.dim >= tensor_num_dim or self.dim < -tensor_num_dim:  # type: ignore[operator]
+            raise ValueError(f"Invalid sharding dim: {self.dim}")
+
+        shards_metadata = []
+        sharding_dim_size = tensor_sizes[self.dim]  # type: ignore[index]
+        chunks = len(self.placements)
+        split_size = get_split_size(sharding_dim_size, chunks)
+        for idx, placement in enumerate(self.placements):
+            # generate ShardMetadata for each placement device
+            chunked_dim_size = get_chunked_dim_size(sharding_dim_size, split_size, idx)
+            if chunked_dim_size > 0:
+                shard_size = list(tensor_sizes)
+                current_offsets = [0] * tensor_num_dim
+                current_offsets[self.dim] = split_size * idx  # type: ignore[index]
+                shard_size[self.dim] = chunked_dim_size  # type: ignore[index]
+
+                shard_metadata = ShardMetadata(
+                    shard_offsets=current_offsets,
+                    shard_sizes=shard_size,
+                    placement=placement,
+                )
+                shards_metadata.append(shard_metadata)
+
+                # current_offsets[self.dim] += chunked_dim_size  # type: ignore[index]
+
+        return sharded_tensor_meta.ShardedTensorMetadata(
+            shards_metadata,
+            tensor_sizes,
+            tensor_properties
+        )
+
+
+    def shard(self, tensor: torch.Tensor, src_rank: int = 0, process_group=None) -> "ShardedTensor":
+        # relative imports to avoid circular dependency
+        from torch.distributed._shard.sharded_tensor import (
+            ShardedTensor
+        )
+        tensor_properties = sharded_tensor_meta.TensorProperties(
+            dtype=tensor.dtype,
+            layout=tensor.layout,
+            requires_grad=tensor.requires_grad,
+            memory_format=torch.contiguous_format,
+            pin_memory=tensor.is_pinned()
+        )
+        current_rank = dist.get_rank(process_group)
+        tensor_meta = self.build_metadata(tensor.size(), tensor_properties)
+        local_shards = []
+        local_tensor = None
+        local_metadata = None
+        tensors_to_scatter = [None] * dist.get_world_size(process_group)
+
+        sharding_dim_size = tensor.size()[self.dim]  # type: ignore[index]
+        chunks = len(self.placements)
+        split_size = get_split_size(sharding_dim_size, chunks)
+        scatter_shape = list(tensor.size())
+        scatter_shape[self.dim] = split_size  # type: ignore[index]
+
+        for shard_meta in tensor_meta.shards_metadata:
+            rank, device = _parse_and_validate_remote_device(process_group, shard_meta.placement)
+            if current_rank == src_rank:
+                # Reshape to get shard for this rank and we don't want autograd
+                # recording here for the narrow op and 'local_shard' should be a
+                # leaf variable in the autograd graph.
+                narrowed_tensor = narrow_tensor(tensor, shard_meta)
+                if shard_meta.shard_sizes[self.dim] < split_size:  # type: ignore[index]
+                    # for the last shard that might be smaller to other shards
+                    # resize the narrowed tensor to the same size and use it for
+                    # the scatter collective as dist.scatter requires same size
+                    # inputs on every rank
+                    tensor_to_scatter = narrowed_tensor.detach().clone().resize_(scatter_shape)
+                else:
+                    tensor_to_scatter = narrowed_tensor.detach().clone().contiguous()
+
+                tensors_to_scatter[rank] = tensor_to_scatter
+
+            if current_rank == rank:
+                local_tensor = torch.empty(
+                    scatter_shape, dtype=tensor.dtype, layout=tensor.layout, device=device)
+                local_metadata = shard_meta
+
+        # Each rank should have local_tensor and local_metadata initialized if we
+        # built the metadata list correctly.
+        assert local_tensor is not None
+        assert local_metadata is not None
+
+        # Scatter the shards to all ranks in the pg
+        dist.scatter(
+            local_tensor,
+            scatter_list=tensors_to_scatter if current_rank == src_rank else None,
+            src=src_rank,
+            group=process_group
+        )
+
+        if list(local_tensor.size()) != local_metadata.shard_sizes:
+            # The last shard was padded for the scatter; resize it back to its real
+            # shard size and detach again so the local shard remains a leaf node.
+            local_tensor = local_tensor.resize_(local_metadata.shard_sizes).detach()
+
+        # Sync requires_grad to local_shard.
+        local_tensor.requires_grad = tensor.requires_grad
+
+        local_shards.append(Shard(tensor=local_tensor, metadata=local_metadata))
+
+        st = ShardedTensor._init_from_local_shards_and_global_metadata(
+            local_shards,
+            tensor_meta,
+            process_group=process_group)
+
+        # Manually set sharding_spec
+        st._sharding_spec = self
+
+        return st
diff --git a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__init__.py b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/__init__.py
diff --git a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/_common.py b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/_common.py
new file mode 100644
index 0000000..c2ef8fe
--- /dev/null
+++ b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/_common.py
@@ -0,0 +1,447 @@
+# coding=utf-8
+
+from typing import List
+
+import torch
+import torch.distributed as dist
+import torch.distributed._shard.sharding_spec as shard_spec
+from torch.distributed._shard.sharded_tensor._ops._common import _sharded_op_common
+from torch.distributed._shard.sharded_tensor import (
+    sharded_op_impl,
+    ShardedTensor,
+)
+from torch.distributed._shard.sharding_spec._internals import (
+    get_split_size,
+    get_chunked_dim_size,
+)
+from torch.distributed.nn.functional import (
+    all_gather,
+    all_to_all_single,
+)
+
+
+def _chunk_sharding_spec_check(spec, op):
+    """
+    For the given op implementation check if the sharding spec is ChunkShardingSpec.
+    """
+    if not isinstance(spec, shard_spec.ChunkShardingSpec):
+        raise NotImplementedError(
+            f"Only ChunkShardingSpec supported for '{op.__name__}'."
+        )
+
+def _register_sharded_op_on_local_tensor(
+    op, early_stop_func=None, extra_check=None, customized_func=None
+):
+    """
+    Handles ``__torch_function__`` dispatch for ops which are performed on
+    the single local tensor of the sharded tensor, such as
+    ``torch.nn.functional.softmax`` or ``torch.Tensor.view``.
+
+    For more complicated ops, a customized func can be used to generate
+    the new local tensor, sharding spec and sharded tensor size.
+
+    Args:
+        op: The op to be registered and applied to all shards of the st.
+        early_stop_func (Callable, optional): the func for early stop.
+            Default: if ``None``, no early stop.
+        extra_check (Callable, optional): the func for extra condition check.
+            Default: if ``None``, no extra check.
+        customized_func (Callable, optional): the func for customized logic
+            to generate the new local tensor, sharding spec and sharded tensor size.
+            Default: if ``None``, we simply lower to the real op call with
+                the single local tensor of the st.
+
+    Return:
+        func (Callable): registered implementation for sharded op for
+        ``__torch_function__`` dispatch.
+    """
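+    # Usage sketch (hypothetical): registering torch.nn.functional.softmax via
+    #   _register_sharded_op_on_local_tensor(torch.nn.functional.softmax)
+    # would run softmax on the single local shard and re-wrap the result;
+    # extra_check/customized_func can be passed for more involved ops.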
+    @sharded_op_impl(op)
+    @_sharded_op_common(op, early_stop_func, extra_check)
+    def sharded_tensor_op_on_local_shards(types, args=(), kwargs=None, pg=None):
+        st = args[0]
+        sharding_spec = st.sharding_spec()
+        _chunk_sharding_spec_check(sharding_spec, op)
+        if len(st.local_shards()) != 1:
+            raise TypeError(
+                f"torch function '{op.__name__}', with args: {args} and "
+                f"kwargs: {kwargs} only supported for single local tensor!"
+            )
+        st_size = st.size()
+        if customized_func:
+            local_tensor, sharding_spec, st_size = customized_func(args, kwargs, pg)
+        else:
+            args = (st.local_tensor(), *args[1:])
+            local_tensor = op(*args, **kwargs)
+        return ShardedTensor._init_from_local_tensor(
+            local_tensor.contiguous(),
+            sharding_spec,
+            st_size,  # type: ignore[arg-type]
+            process_group=pg,
+            init_rrefs=st._init_rrefs,
+        )
+
+
+def _handle_col_wise_sharding_base(
+    op_func,
+    col_dim,
+    input,
+    world_size,
+    weight,
+    local_shard,
+    pg,
+    gathered_inputs=None,
+    mode=None,
+    gathered_per_sample_weights=None,
+    gathered_offsets=None,
+    padding_idx=None,
+):
+    """
+    For col-wise sharding of weight, much of the logic is common,
+    so we extract it into this function:
+    Step 1. Gather the inputs from all ranks.
+    Step 2. Perform the op on each gathered input with the local shard.
+    Step 3. Distribute the results to each rank with column rearrangement.
+    Step 4. Concatenate all results from all ranks.
+
+    Args:
+        op_func: operator which is applied to the input tensor.
+        col_dim: dim of result tensor after the operation.
+        input: tensor the op is applied to.
+        world_size: number of ranks.
+        weight: sharded weight tensor.
+        local_shard: col-wise sharded weight tensor.
+        pg: process group.
+        gathered_inputs: list of inputs from all ranks. If specified, we
+            don't need to communicate with each rank any more.
+        mode: aggregation mode of EmbeddingBag.
+        gathered_per_sample_weights: per_sample_weights across all ranks.
+        gathered_offsets: offsets across all ranks.
+        padding_idx: If specified, the entries at padding_idx do
+            not contribute to the gradient; therefore, the embedding
+            vector at padding_idx is not updated during training,
+            i.e. it remains as a fixed “pad”.
+            Note that the embedding vector at padding_idx is
+            excluded from the reduction.
+
+    Return: final result of the op applied to the input.
+    """
+    if gathered_inputs is None:
+        # allgather the inputs first.
+        gathered_inputs = all_gather(input, group=pg)
+
+    # run the operator's function for all the inputs.
+    results = []
+    for i, inp in enumerate(gathered_inputs):
+        if op_func == torch.nn.functional.embedding_bag:
+            result = op_func(
+                inp,
+                local_shard,
+                offsets=gathered_offsets[i] if gathered_offsets is not None else None,
+                mode=mode,
+                per_sample_weights=gathered_per_sample_weights[i]
+                if gathered_per_sample_weights is not None
+                else None,
+                padding_idx=padding_idx,
+            )
+        elif op_func == torch.nn.functional.embedding:
+            result = op_func(
+                inp,
+                local_shard,
+                padding_idx=padding_idx,
+            )
+        else:
+            result = op_func(inp, local_shard)
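+        # Transpose so the output dim (col_dim) becomes dim 0, which lets the
+        # per-rank results be concatenated and redistributed along dim 0 below.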
+        results.append(torch.transpose(result, 0, col_dim))
+
+    # Distribute results to each rank with col rearrangement.
+    output = _result_distribute_with_col_rearrange(
+        results, input, world_size, weight, pg
+    )
+
+    # transpose the output and return result.
+    return torch.transpose(output, 0, col_dim)
+
+
+def _result_distribute_with_col_rearrange(
+    results, input, world_size, weight, pg
+):
+    """
+    For col-wise sharding of weight, we need to distribute the
+    results back to each rank. We do that in this function.
+    Note that, if the index in the Sharding Spec is not equal to
+    the rank number, we need to do the rearrangement based on the
+    order given by the Sharding Spec (placement).
+
+    Args:
+        results: results from ops applied to inputs from all ranks.
+            We need to distribute them back to their original ranks.
+        input: tensor the op is applied to.
+        world_size: number of ranks.
+        weight: sharded weight tensor.
+        pg: process group.
+
+    Return: column rearranged result.
+    """
+    # Process results and outputs for all2all.
+    sharding_dim = weight._sharding_spec.dim
+    sharding_dim_size = weight.size(sharding_dim)
+    dims = list(results[0].size())
+    dims[0] = sharding_dim_size
+    output = torch.empty(*dims, device=input.device)
+    combined_results = torch.cat(results)
+
+    # Compute output splits
+    split_size = get_split_size(sharding_dim_size, world_size)
+    output_split_sizes = [0] * world_size
+    for idx, placement in enumerate(weight._sharding_spec.placements):
+        output_split_sizes[placement.rank()] = get_chunked_dim_size(
+            sharding_dim_size, split_size, idx
+        )
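+    # e.g. (hypothetical) sharding_dim_size=5, world_size=2 and placements ordered
+    # [rank1, rank0]: split_size=3, the chunk sizes are [3, 2], so
+    # output_split_sizes == [2, 3] (indexed by rank, not by placement order).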
+
+    # distribute the outputs using all2all.
+    output = all_to_all_single(
+        output, combined_results, output_split_sizes=output_split_sizes, group=pg
+    )
+
+    # Check if we need to rearrange columns appropriately for output.
+    rearrange_columns = any(
+        [
+            idx != placement.rank()
+            for idx, placement in enumerate(weight._sharding_spec.placements)
+        ]
+    )
+    if not rearrange_columns:
+        return output
+
+    indices = []
+    for placement in weight._sharding_spec.placements:
+        dim_size = output_split_sizes[placement.rank()]
+        start = sum(
+            [
+                split_size if i < placement.rank() else 0
+                for i, split_size in enumerate(output_split_sizes)
+            ]
+        )
+        indices += list(range(start, start + dim_size))
+
+    return output.index_select(0, torch.tensor(indices, device=output.device))
+
+
+def _handle_row_wise_lookup_distribute(
+    input_sorted, input, world_size, weight, rank, padding_idx
+):
+    """
+    In the circumstance of row-wise sharding of weight, we need to distribute
+    the sorted lookup IDs of embedding/embeddingBag to each rank.
+    If the index in the placement is not equal to the rank number, we need to
+    do the rearrangement based on the order given by the Sharding Spec (placement).
+
+    In addition, we do two things for padding_idx: we only keep it if it falls
+    within the range of the current rank, and we take it modulo
+    sharded_dim_size_max.
+
+    Args:
+        input_sorted: sorted lookup IDs of embedding/embeddingBag.
+        input: tensor the op is applied to.
+        world_size: number of ranks.
+        weight: sharded weight tensor.
+        rank: rank of the current CUDA process.
+        padding_idx: If specified, the entries at padding_idx do
+            not contribute to the gradient and reduction.
+
+    Return:
+        input_sorted: sorted lookup IDs of embedding/embeddingBag
+            Rearrangement performed if it is needed.
+        input_split_sizes: size of IDs to be assigned to each rank.
+        sharded_dim_size_max: the max size of the row each rank gets.
+        input_split_rearrange_indices: indices of row rearrangement.
+        rearrange_indices_1d_second_order: reverse indices of row
+            rearrangement, which will be used to restore the original
+            order.
+        padding_idx: Same as the input padding_idx if it is within the range
+            of the given rank; otherwise, None is returned. It is
+            also taken modulo sharded_dim_size_max.
+    """
+    # Decide which rank each input ID goes to by checking the sharding range.
+    split_size = get_split_size(weight.size(0), world_size)
+    rearrange_rows = False
+    indices_flatten = None
+    input_split_sizes: List[int] = [0] * world_size
+    input_split_start_indices: List[int] = [0] * world_size
+    start_row_idx_rank = None
+    end_row_idx_rank = None
+    # When we do the chunk split, we always ensure the first N - 1 chunks get max out
+    # and then the Nth chunk gets the rest. So input_split_sizes like [3, 3, 3, 4]
+    # are not possible. The expected split size will be [4, 4, 4, 1].
+    sharded_dim_size_max = get_chunked_dim_size(weight.size(0), split_size, 0)
+    for idx, placement in enumerate(weight._sharding_spec.placements):
+        sharded_dim_size = get_chunked_dim_size(weight.size(0), split_size, idx)
+        start_row_idx = idx * sharded_dim_size_max
+        end_row_idx = start_row_idx + sharded_dim_size
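+        # searchsorted over the sorted lookup IDs gives the contiguous slice of
+        # IDs falling into [start_row_idx, end_row_idx), i.e. the IDs owned by
+        # this placement's shard.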
+        start_idx = torch.searchsorted(input_sorted, start_row_idx).item()
+        end_idx = torch.searchsorted(input_sorted, end_row_idx).item()
+        input_split_sizes[placement.rank()] = int(end_idx - start_idx)
+        input_split_start_indices[placement.rank()] = int(start_idx)
+        if placement.rank() != idx:
+            rearrange_rows = True
+        # Store the range of the current rank.
+        if placement.rank() == rank:
+            start_row_idx_rank = start_row_idx
+            end_row_idx_rank = end_row_idx
+
+    # Take padding_idx modulo sharded_dim_size_max if it is within this rank's range.
+    if padding_idx is not None:
+        if padding_idx < start_row_idx_rank or padding_idx >= end_row_idx_rank:
+            padding_idx = None
+        else:
+            padding_idx = padding_idx % sharded_dim_size_max
+
+    rearrange_indices_1d_second_order = None
+    if rearrange_rows:
+        # Need to re-arrange the 1D tensor to be sent via all2all.
+        indices: List[List[int]] = [[0]] * world_size
+        for placement in weight._sharding_spec.placements:
+            split_length = input_split_sizes[placement.rank()]
+            offset_idx = input_split_start_indices[placement.rank()]
+            indices[placement.rank()] = list(
+                range(offset_idx, offset_idx + split_length)
+            )
+        indices_flatten = list(idx for indice in indices for idx in indice)
+
+        input_sorted = input_sorted.index_select(
+            0, torch.tensor(indices_flatten, device=input.device)
+        )
+        rearrange_indices_1d_second_order = torch.argsort(torch.Tensor(indices_flatten))
+
+    return (
+        input_sorted,
+        input_split_sizes,
+        sharded_dim_size_max,
+        torch.tensor(indices_flatten, device=input.device) if rearrange_rows else None,
+        rearrange_indices_1d_second_order,
+        padding_idx,
+    )
+
+
+def _communicate_size_to_each_rank(
+    input_size_list, output_size, input, pg, tensor_type=torch.int
+):
+    """
+    In the circumstance of row-wise sharding of weight, we need to first
+    communicate the input length to each rank because each rank gets a
+    different one.
+
+    Args:
+        input_size_list: list of sizes to be sent to each rank.
+        output_size: length of the output tensor.
+        input: tensor the op is applied to.
+        pg: process group.
+        tensor_type: dtype of tensor.
+
+    Return: A list of ints: the sizes this rank will receive from each rank.
+    """
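+    # input_size_list[i] is the number of elements this rank sends to rank i;
+    # the returned list holds, for each rank, how many elements will be
+    # received from it (via dist.all_to_all_single).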
+    input_size_list_tensor = torch.tensor(
+        input_size_list, dtype=tensor_type, device=input.device
+    )
+    output_size_list_tensor = torch.empty(
+        output_size, dtype=tensor_type, device=input.device
+    )
+    dist.all_to_all_single(
+        output_size_list_tensor,
+        input_size_list_tensor,
+        group=pg,
+    )
+    return output_size_list_tensor.tolist()
+
+
+def _communicate_list_to_each_rank(
+    input_tensor_list, output_lists, input, pg, tensor_type=torch.int64
+):
+    """
+    In the circumstance of row-wise sharding of weight, we need to
+    communicate a list of input tensors to each rank. Because the
+    input could be a list of lists, we first convert each list
+    to a tensor.
+
+    Args:
+        input_tensor_list: list of tensors to be sent to each rank.
+        output_lists: list of sizes to be obtained from each rank.
+        input: tensor the op is applied to.
+        pg: process group.
+        tensor_type: dtype of tensor.
+
+    Return: A list of communication results (tensors).
+    """
+    output_tensor_list = []
+    for output_list in output_lists:
+        output_tensor_list.append(
+            torch.empty(output_list, dtype=tensor_type, device=input.device)
+        )
+    dist.all_to_all(
+        output_tensor_list,
+        input_tensor_list,
+        group=pg,
+    )
+    return output_tensor_list
+
+
+def _handle_max_norm_col_wise(
+    max_norm,
+    norm_type,
+    local_shard,
+    input,
+    world_size,
+    pg,
+):
+    """
+    For col-wise sharding of weight, we need to aggregate the
+    norm across all ranks before we can perform the proper re-norm.
+    Note that the max_norm logic is only applied to the embedding
+    indices that are looked up and not the whole shard.
+
+    Args:
+        max_norm: If given, each embedding vector with norm larger
+            than max_norm is renormalized to have norm max_norm.
+            Note: this will modify weight in-place.
+        norm_type: The p in the p-norm to compute for the max_norm option.
+        local_shard: col-wise sharded local weight used for lookup.
+        input: tensor the op is applied to.
+        world_size: number of ranks.
+        pg: process group.
+
+    Return:
+        local_shard_norm_renormed: local_shard re-normed to max_norm if the norm is larger
+            than it.
+        gathered_inputs: list of inputs from all ranks.
+    """
+    norm_type = norm_type if norm_type is not None else 2.0
+    # allgather the inputs first.
+    gathered_inputs = [torch.zeros_like(input) for _ in range(world_size)]
+    dist.all_gather(gathered_inputs, input, group=pg)
+    unique_inp = torch.unique(torch.cat(gathered_inputs))
+    local_shard_sum = torch.sum(
+        torch.pow(torch.abs(local_shard), norm_type), dim=1, dtype=local_shard.dtype
+    )
+    # For col-wise sharding, we need to first aggregate the powered sum
+    # from each rank first and then calculate the norm.
+    dist.all_reduce(local_shard_sum, group=pg)
+    local_shard_norm = torch.pow(local_shard_sum, 1.0 / norm_type)
+    max_norm_tensor = torch.full(
+        (local_shard.size(0),),
+        float("inf"),
+        dtype=local_shard.dtype,
+        device=input.device,
+    )
+    max_norm_tensor[unique_inp] = max_norm
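+    # Rows that are not referenced by the lookup keep an "inf" threshold so the
+    # renorm below leaves them untouched.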
+    local_shard_t = local_shard.t().contiguous()
+    normalized_tensor = torch.where(
+        local_shard_norm > max_norm_tensor, max_norm_tensor, local_shard_norm
+    )
+    # Make sure divisor is not zero.
+    local_shard_norm[local_shard_norm == 0.0] = 1.0
+    local_shard_norm_renormed = (
+        torch.div(torch.mul(local_shard_t, normalized_tensor), local_shard_norm)
+        .t()
+        .contiguous()
+    )
+    return local_shard_norm_renormed, gathered_inputs
diff --git a/torch/distributed/_shard/sharded_tensor/_ops/embedding.py b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding.py
similarity index 97%
rename from torch/distributed/_shard/sharded_tensor/_ops/embedding.py
rename to torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding.py
index 11ef153..232e716 100644
--- a/torch/distributed/_shard/sharded_tensor/_ops/embedding.py
+++ b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding.py
@@ -11,13 +11,13 @@
     _handle_max_norm_col_wise,
 )
 from torch.distributed._shard.sharding_spec import ChunkShardingSpec
+from torch.distributed._shard.sharding_spec.api import custom_sharding_spec_op
 from torch.distributed._shard.sharded_tensor import (
-    sharded_op_impl,
     ShardedTensor
 )
 
-@sharded_op_impl(torch.nn.functional.embedding)
-def sharded_embedding(types, args, kwargs, pg):
+@custom_sharding_spec_op(ChunkShardingSpec, torch.nn.functional.embedding)
+def sharded_embedding(types, args, kwargs):
     """
     Handles ``__torch_function__`` dispatch for ``torch.nn.functional.embedding``.
     This method computes a sharded embedding lookup and has the following limitations:
@@ -104,6 +104,7 @@
     norm_type = kwargs.get("norm_type")
     padding_idx = kwargs.get("padding_idx")
 
+    pg = weight._process_group
     local_shard = weight.local_tensor().contiguous()
     sharding_dim = weight._sharding_spec.dim
     world_size = dist.get_world_size(pg)
diff --git a/torch/distributed/_shard/sharded_tensor/_ops/embedding_bag.py b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py
similarity index 98%
rename from torch/distributed/_shard/sharded_tensor/_ops/embedding_bag.py
rename to torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py
index 6fad875..d3151a3 100644
--- a/torch/distributed/_shard/sharded_tensor/_ops/embedding_bag.py
+++ b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/embedding_bag.py
@@ -15,14 +15,14 @@
     _handle_max_norm_col_wise,
 )
 from torch.distributed._shard.sharding_spec import ChunkShardingSpec
+from torch.distributed._shard.sharding_spec.api import custom_sharding_spec_op
 from torch.distributed._shard.sharded_tensor import (
-    sharded_op_impl,
     ShardedTensor
 )
 
 
-@sharded_op_impl(torch.nn.functional.embedding_bag)
-def sharded_embedding_bag(types, args, kwargs, pg):
+@custom_sharding_spec_op(ChunkShardingSpec, torch.nn.functional.embedding_bag)
+def sharded_embedding_bag(types, args, kwargs):
     """
     Handles ``__torch_function__`` dispatch for ``torch.nn.functional.embedding_bag``.
     This method computes a sharded embedding bag aggregation and has the following limitations:
@@ -123,6 +123,7 @@
     include_last_offset = kwargs.get("include_last_offset")
     padding_idx = kwargs.get("padding_idx")
 
+    pg = weight._process_group
     local_shard = weight.local_tensor().contiguous()
     sharding_dim = weight._sharding_spec.dim
     world_size = dist.get_world_size(pg)
diff --git a/torch/distributed/_shard/sharded_tensor/_ops/linear.py b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/linear.py
similarity index 98%
rename from torch/distributed/_shard/sharded_tensor/_ops/linear.py
rename to torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/linear.py
index 560ba91..87eefd9 100644
--- a/torch/distributed/_shard/sharded_tensor/_ops/linear.py
+++ b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/linear.py
@@ -9,10 +9,10 @@
 )
 from torch.distributed._shard.partial_tensor import _PartialTensor
 from torch.distributed._shard.sharded_tensor import (
-    sharded_op_impl,
     ShardedTensor,
 )
 from torch.distributed._shard.sharding_spec import ChunkShardingSpec
+from torch.distributed._shard.sharding_spec.api import custom_sharding_spec_op
 from torch.distributed._shard.sharding_spec._internals import (
     get_split_size,
     get_chunked_dim_size,
@@ -24,8 +24,8 @@
 )
 
 
-@sharded_op_impl(torch.nn.functional.linear)
-def sharded_linear(types, args, kwargs, pg):
+@custom_sharding_spec_op(ChunkShardingSpec, torch.nn.functional.linear)
+def sharded_linear(types, args, kwargs):
     """
     Handles ``__torch_function__`` dispatch for ``torch.nn.functional.linear``.
     This method computes a sharded linear and has the following limitations:
@@ -99,6 +99,7 @@
     weight = args[1]
     bias = args[2]
 
+    pg = weight._process_group
     local_shard = weight.local_tensor()
     local_shard_t = local_shard.t().contiguous()
     sharding_dim = weight._sharding_spec.dim
diff --git a/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/math_ops.py b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/math_ops.py
new file mode 100644
index 0000000..5f5b59f
--- /dev/null
+++ b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/math_ops.py
@@ -0,0 +1,72 @@
+import torch
+from torch import Tensor
+from torch.distributed._shard.sharded_tensor import (
+    ShardedTensor,
+)
+from torch.distributed._shard.sharding_spec import ChunkShardingSpec
+from torch.distributed._shard.sharding_spec.api import custom_sharding_spec_op
+from torch.distributed._shard.sharded_tensor._ops.math_ops import binary_math_op_impl
+
+from ._common import (
+    _chunk_sharding_spec_check,
+)
+
+def register_math_op(op):
+    @custom_sharding_spec_op(ChunkShardingSpec, op)
+    def binary_math_op(types, args=(), kwargs=None):
+        """
+        Handles ``__torch_function__`` dispatch for the binary math ops
+        such as `torch.add`, `torch.mul`, `torch.div`, etc.
+        This method computes the op on ShardedTensor inputs.
+        """
+        if len(args) != 2:
+            raise ValueError("Only support binary math op on ShardedTensor for now!")
+        lhs = args[0]
+        rhs = args[1]
+        pg = lhs._process_group if isinstance(lhs, ShardedTensor) else rhs._process_group
+        # Validate types
+        if isinstance(lhs, ShardedTensor) and isinstance(rhs, ShardedTensor):
+            lhs_spec = lhs.sharding_spec()
+            rhs_spec = rhs.sharding_spec()
+            _chunk_sharding_spec_check(lhs_spec, op)
+            _chunk_sharding_spec_check(rhs_spec, op)
+
+            if lhs.size() == rhs.size() and lhs_spec.dim == rhs_spec.dim:  # type: ignore[attr-defined]
+                # perform local element-wise math op
+                res = op(lhs.local_tensor(), rhs.local_tensor())
+                return ShardedTensor._init_from_local_tensor(
+                    res,
+                    lhs_spec,
+                    lhs.size(),  # type: ignore[arg-type]
+                    process_group=pg)
+            else:
+                raise RuntimeError("Implicit broadcasting not supported yet!")
+        else:
+            # Try dispatch to ShardingSpec agnostic ops.
+            return binary_math_op_impl(op, types, args, kwargs, pg)
+
+binary_ops = [
+    # add
+    torch.add,
+    Tensor.add,
+    Tensor.__add__,
+    Tensor.__radd__,
+    # sub
+    torch.sub,
+    Tensor.sub,
+    Tensor.__sub__,
+    Tensor.__rsub__,
+    # mul
+    torch.mul,
+    Tensor.mul,
+    Tensor.__mul__,
+    Tensor.__rmul__,
+    # div
+    torch.div,
+    Tensor.div,
+    Tensor.__div__,
+    Tensor.__rdiv__,
+]
+
+for op in binary_ops:
+    register_math_op(op)
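+
+# Example (sketch): once this module is imported, ``torch.add(st1, st2)`` on two
+# identically-sharded ChunkShardingSpec tensors dispatches to ``binary_math_op``
+# above and applies the op shard-locally.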
diff --git a/torch/distributed/_shard/sharded_tensor/_ops/matrix_ops.py b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/matrix_ops.py
similarity index 80%
rename from torch/distributed/_shard/sharded_tensor/_ops/matrix_ops.py
rename to torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/matrix_ops.py
index 4fdf9a3..5bf3fb2 100644
--- a/torch/distributed/_shard/sharded_tensor/_ops/matrix_ops.py
+++ b/torch/distributed/_shard/sharding_spec/chunk_sharding_spec_ops/matrix_ops.py
@@ -4,73 +4,14 @@
 import torch
 import torch.distributed as dist
 from torch.distributed._shard.sharded_tensor import (
-    Shard,
     ShardedTensor,
 )
 
 from ._common import (
     _chunk_sharding_spec_check,
     _register_sharded_op_on_local_tensor,
-    _register_sharded_op_on_local_shards,
 )
 
-
-def sharded_type_as_check(*args, **kwargs):
-    """
-    Perform extra checks for the sharded_type_as op such as the input needs to
-    be either a Tensor or ShardedTensor.
-
-    Args: same as ``torch.Tensor.type_as``.
-
-    Return: None
-    """
-    if len(args) < 2:
-        raise ValueError("Needs to give a tensor to cast type as!")
-    if not isinstance(args[1], torch.Tensor) and not isinstance(args[1], ShardedTensor):
-        raise ValueError("Needs to give a Tensor or ShardedTensor to cast type as!")
-
-
-def same_dtype(*args, **kwargs):
-    """
-    When the dtype is the same, return the original ShardedTensor.
-
-    Args: same as ``torch.Tensor.type_as``.
-
-    Return (bool): Whether to return early or not.
-    """
-    return args[0].dtype == args[1].dtype
-
-
-def sharded_type_as(args, kwargs, pg):
-    """
-    Handles ``__torch_function__`` dispatch for the ``torch.Tensor.type_as`` op.
-
-    Args: same as ``torch.Tensor.type_as``.
-
-    Return:
-        new_local_shards (List[Shard]): Local shards for the new sharded tensor.
-        st_meta (ShardedTensorMetadata): Metadata of the new sharded tensor.
-    """
-    st = args[0]
-    tensor = args[1]
-    if isinstance(tensor, ShardedTensor):
-        tensor = tensor.local_tensor()
-    new_local_shards = []
-    for shard in st.local_shards():
-        new_local_shards.append(Shard(shard.tensor.type_as(tensor), shard.metadata))
-    st_meta = copy.deepcopy(st._metadata)
-    st_meta.tensor_properties.dtype = tensor.dtype
-    return new_local_shards, st_meta
-
-
-_register_sharded_op_on_local_shards(
-    torch.Tensor.type_as,
-    early_stop_func=same_dtype,
-    extra_check=sharded_type_as_check,
-    customized_func=sharded_type_as,
-)
-
-
 def transpose_same_dim(*args, **kwargs):
     """
     When the dim0 and dim1 of transpose are the same, return the original ShardedTensor.
@@ -322,3 +263,68 @@
     extra_check=sharded_view_check,
     customized_func=sharded_view,
 )
+
+def sharded_bmm_check(*args, **kwargs):
+    """
+    Perform extra checks for the sharded_bmm op, for example, st2 needs to
+    be a ShardedTensor and both tensors need to be sharded by dim 0, etc.
+
+    Args: same as ``torch.bmm``.
+
+    Return: None
+    """
+    if len(args) < 2:
+        raise TypeError("Needs two tensors to perform torch.bmm.")
+    st = args[0]
+    st2 = args[1]
+    # Validate types
+    if not isinstance(st2, ShardedTensor):
+        raise TypeError("st2 needs to be a ShardedTensor for torch.bmm.")
+    _chunk_sharding_spec_check(st2.sharding_spec(), torch.bmm)
+    if st.dim() != 3 or st2.dim() != 3:
+        raise TypeError("both st and st2 need to be a 3D ShardedTensor")
+    if (
+        st.sharding_spec().dim != st2.sharding_spec().dim  # type: ignore[attr-defined]
+        or st.sharding_spec().dim != 0
+    ):
+        raise NotImplementedError(
+            "Only support performing bmm on tensors sharded on dim 0 now."
+        )
+    if st.sharding_spec().placements != st2.sharding_spec().placements:  # type: ignore[attr-defined]
+        raise NotImplementedError(
+            "Both st and st2 need to have same placements for bmm."
+        )
+
+def sharded_bmm(args, kwargs, pg):
+    """
+    Handles ``__torch_function__`` dispatch for the sharded_bmm op.
+
+    Warning: For now we only support the case where both tensors are sharded
+             by dim 0 so that no inter-rank communication is needed.
+
+    Args: same as ``torch.bmm``.
+
+    Return:
+        local_tensor (Tensor): New local tensor to build the sharded tensor.
+        sharding_spec (:class:`torch.distributed._shard.sharding_spec.ShardingSpec`):
+            sharding spec of the new sharded tensor.
+        new_st_size (torch.Size): Size of the new sharded tensor.
+    """
+    st = args[0]
+    st2 = args[1]
+    local_tensor = torch.bmm(st.local_tensor(), st2.local_tensor())
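+    # Both inputs are sharded on dim 0 (the batch dim), so the bmm is purely
+    # local; the global size is (*st.size()[:-1], st2.size(-1)), e.g.
+    # (B, M, K) x (B, K, N) -> (B, M, N).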
+    new_st_size = (*st.size()[:-1], st2.size(-1))
+    return local_tensor, st.sharding_spec(), new_st_size
+
+
+_register_sharded_op_on_local_tensor(
+    torch.Tensor.bmm,
+    extra_check=sharded_bmm_check,
+    customized_func=sharded_bmm,
+)
+
+_register_sharded_op_on_local_tensor(
+    torch.bmm,
+    extra_check=sharded_bmm_check,
+    customized_func=sharded_bmm,
+)