[shard] Add ReplicatedTensor (#73529)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73529

Add ReplicatedTensor. A ReplicatedTensor is a type of tensor that has the same value on all ranks across the world_size.

ReplicatedTensor is a :class:`~torch.Tensor` subclass, and it can be used together with ShardedTensor/Tensor to express different types of computation. The inter-op rules are defined as follows (using torch.add as an example op):
    ReplicatedTensor + ReplicatedTensor = ReplicatedTensor
    ReplicatedTensor + torch.Tensor = torch.Tensor
    ReplicatedTensor + ShardedTensor = ShardedTensor
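
For example, a rough usage sketch mirroring the tests added in this PR (it assumes NCCL
process groups are already initialized and that `rank` is the current rank):

    local_tensor = torch.ones(3, 3, device=f"cuda:{rank}") * 4
    replica_tensor = ReplicatedTensor(local_tensor)

    res = replica_tensor + 2                  # plain torch.Tensor: the scalar operand is not replicated
    prod = replica_tensor * replica_tensor    # stays a ReplicatedTensor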

We also added a `validate()` API to help users check whether a replicated tensor on a certain process_group is truly replicated.
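
For instance, a sketch based on test_replicated_tensor_basics in this PR (validate()
all-gathers the local values on the tensor's process_group and compares them):

    local_tensor = torch.ones(3, 3, device=f"cuda:{rank}")
    replica_tensor = ReplicatedTensor(local_tensor)
    assert replica_tensor.validate()   # True: every rank holds the same values

    if rank == 2:
        local_tensor += 3              # manually diverge the value on one rank
    replica_tensor.validate()          # now raises ValueError('... have different values ...')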

TODO: a follow-up PR will add ShardedTensor/PartialTensor logic to handle ReplicatedTensor.
ghstack-source-id: 152064781

Test Plan: test_replicated_tensor

Reviewed By: pritamdamania87, fduwjj

Differential Revision: D34529374

fbshipit-source-id: 16ccb300e9f9c47ac29a17eb6d46d029ab7d60b8
(cherry picked from commit 44f4e11e795a1bf330a8108bda256950ca769525)
diff --git a/test/distributed/_shard/test_replicated_tensor.py b/test/distributed/_shard/test_replicated_tensor.py
new file mode 100644
index 0000000..474fbfb
--- /dev/null
+++ b/test/distributed/_shard/test_replicated_tensor.py
@@ -0,0 +1,75 @@
+# Owner(s): ["oncall: distributed"]
+
+import torch
+
+import torch.distributed as dist
+
+from torch.testing._internal.common_distributed import (
+    requires_nccl,
+    skip_if_lt_x_gpu,
+)
+
+from torch.testing._internal.distributed._shard.sharded_tensor import (
+    ShardedTensorTestBase,
+    with_comms,
+)
+from torch.distributed._shard.replicated_tensor import ReplicatedTensor
+
+
+class TestReplicatedTensor(ShardedTensorTestBase):
+
+    @with_comms(init_rpc=False)
+    @skip_if_lt_x_gpu(4)
+    @requires_nccl()
+    def test_replicated_tensor_basics(self):
+        local_tensor = torch.ones(3, 3, device=f"cuda:{self.rank}") * 4
+        replica_tensor = ReplicatedTensor(local_tensor)
+        # validate that it's a replicated tensor by checking values on all ranks
+        validated = replica_tensor.validate()
+        self.assertEqual(validated, True)
+        res = replica_tensor + 2
+        self.assertIsInstance(res, torch.Tensor)
+        self.assertNotIsInstance(res, ReplicatedTensor)
+        self.assertEqual(res, torch.ones(3, 3) * 6)
+
+        # modify the local tensor on a certain rank, and test that validation raises
+        if self.rank == 2:
+            local_tensor += 3
+
+        with self.assertRaisesRegex(ValueError, 'have different values'):
+            replica_tensor.validate()
+
+    @with_comms(init_rpc=False)
+    @skip_if_lt_x_gpu(4)
+    @requires_nccl()
+    def test_replicated_tensor_inter_op_replicated_tensor(self):
+        local_tensor = torch.ones(3, 3, device=f"cuda:{self.rank}")
+        replica_tensor1 = ReplicatedTensor(local_tensor * 4)
+        replica_tensor2 = ReplicatedTensor(local_tensor * 6)
+
+        new_tensor = replica_tensor1 * replica_tensor2
+        self.assertIsInstance(new_tensor, ReplicatedTensor)
+        self.assertEqual(new_tensor, torch.ones(3, 3) * 24)
+
+        # test replicated tensor inter-op with different pgs
+        new_pg = dist.new_group(ranks=[1, 2, 3])
+        replica_tensor_new_group = ReplicatedTensor(local_tensor * 3, process_group=new_pg)
+
+        with self.assertRaisesRegex(RuntimeError, 'must be in the same'):
+            replica_tensor_new_group * replica_tensor1
+
+
+    @with_comms(init_rpc=False)
+    @skip_if_lt_x_gpu(4)
+    @requires_nccl()
+    def test_replicated_tensor_inter_op_tensor(self):
+        local_tensor = torch.ones(3, 3, device=f"cuda:{self.rank}") * 4
+        replica_tensor = ReplicatedTensor(local_tensor)
+
+        local_rand_tensor = torch.randn(3, 3, device=f"cuda:{self.rank}")
+
+        new_tensor = replica_tensor + local_rand_tensor
+        self.assertIsInstance(new_tensor, torch.Tensor)
+        self.assertNotIsInstance(new_tensor, ReplicatedTensor)
+
+        self.assertEqual(new_tensor, local_tensor + local_rand_tensor)
diff --git a/test/run_test.py b/test/run_test.py
index 176b35b..d378f15 100644
--- a/test/run_test.py
+++ b/test/run_test.py
@@ -214,6 +214,7 @@
     "distributed/_shard/sharded_tensor/ops/test_linear",
     "distributed/_shard/sharding_spec/test_sharding_spec",
     "distributed/_shard/sharded_optim/test_sharded_optim",
+    "distributed/_shard/test_replicated_tensor",
 ] + FSDP_TEST
 
 ROCM_BLOCKLIST = [
@@ -233,6 +234,7 @@
     "distributed/_shard/sharded_tensor/ops/test_linear",
     "distributed/_shard/sharding_spec/test_sharding_spec",
     "distributed/_shard/sharded_optim/test_sharded_optim",
+    "distributed/_shard/test_replicated_tensor",
     "test_determination",
     "test_jit_legacy",
     "test_type_hints",
diff --git a/torch/distributed/_shard/__init__.py b/torch/distributed/_shard/__init__.py
index b6f0776..194ae2c 100644
--- a/torch/distributed/_shard/__init__.py
+++ b/torch/distributed/_shard/__init__.py
@@ -1 +1 @@
-from .api import shard_parameter, _shard_tensor
+from .api import shard_parameter, _shard_tensor, _replicate_tensor
diff --git a/torch/distributed/_shard/api.py b/torch/distributed/_shard/api.py
index 0de8a59..c5082b0 100644
--- a/torch/distributed/_shard/api.py
+++ b/torch/distributed/_shard/api.py
@@ -7,6 +7,7 @@
 from .sharding_spec import (
     ShardingSpec,
 )
+from .replicated_tensor import ReplicatedTensor
 
 def _shard_tensor(
     tensor: torch.Tensor, sharding_spec: ShardingSpec, src_rank=0, process_group=None
@@ -118,3 +119,20 @@
 
     # Now we can set the attribute appropriately.
     setattr(module, param_name, st)
+
+
+def _replicate_tensor(tensor: torch.Tensor, process_group=None) -> ReplicatedTensor:
+    """
+    Given a :class:`torch.Tensor`, mark it as a ReplicatedTensor where all
+    ranks have the same value.
+
+    Args:
+        tensor (:class:`torch.Tensor`): the tensor to be marked as replicated.
+    Keyword args:
+        process_group (ProcessGroup, optional): The process group to replicate on.
+            If None, the default process group will be used.
+    Returns:
+        A :class:`ReplicatedTensor` from the given tensor.
+
+    """
+    return ReplicatedTensor(tensor, process_group=process_group)
diff --git a/torch/distributed/_shard/replicated_tensor.py b/torch/distributed/_shard/replicated_tensor.py
new file mode 100644
index 0000000..12253a0
--- /dev/null
+++ b/torch/distributed/_shard/replicated_tensor.py
@@ -0,0 +1,125 @@
+import torch
+import torch.distributed as dist
+
+from torch.overrides import get_default_nowrap_functions
+from torch.distributed._shard.sharded_tensor.api import ShardedTensor
+from torch.distributed import distributed_c10d
+
+
+class ReplicatedTensor(torch.Tensor):
+    """
+    ReplicatedTensor represents a tensor which is replicated across the `world_size` and
+    has the same value on each rank.
+
+    ReplicatedTensor is a :class:`~torch.Tensor` subclass, and it can be used together
+    with ShardedTensor/Tensor to express different types of computation. The
+    inter-op rules are defined as follows (using torch.add as an example op):
+        ReplicatedTensor + ReplicatedTensor = ReplicatedTensor
+        ReplicatedTensor + torch.Tensor = torch.Tensor
+        ReplicatedTensor + ShardedTensor = ShardedTensor
+        ReplicatedTensor + other type (e.g. a scalar) = other type
+
+    NOTE: We do not guarantee equal content of a ReplicatedTensor across nodes after
+    its construction. Although we define proper inter-op rules to keep a ReplicatedTensor
+    replicated, there is no enforcement of this (i.e. if you manually modify the content
+    on some ranks, the modified values will not automatically be synced to other nodes).
+    If you wish to manually check that tensors are the same across ranks, use `validate()`.
+
+    """
+    process_group: distributed_c10d.ProcessGroup
+
+    __slots__ = ["process_group"]
+
+    def __new__(cls, data=None, process_group=None):
+        if data is None:
+            data = torch.empty(0)
+        r = torch.Tensor._make_subclass(cls, data)      # type: ignore[arg-type]
+        r.process_group = (     # type: ignore[attr-defined]
+            process_group
+            if process_group is not None
+            else distributed_c10d._get_default_group()
+        )
+        return r
+
+    def __repr__(self):
+        return f"ReplicatedTensor({super(ReplicatedTensor, self).__repr__()})"
+
+    @classmethod
+    def __torch_function__(cls, func, types, args=(), kwargs=None):
+        if kwargs is None:
+            kwargs = {}
+        # We re-dispatch the execution to ShardedTensor's __torch_function__ if we
+        # find any ShardedTensor operands. We also check whether all args/kwargs are
+        # ReplicatedTensor operands; we have to do this to ensure we do not convert
+        # the result back to a ReplicatedTensor when not all operands are replicated.
+        all_replicated = True
+        replicated_pg = None
+
+        def dispatch_arg(arg):
+            nonlocal replicated_pg, all_replicated
+            if isinstance(arg, ShardedTensor):
+                # redispatch to ShardedTensor
+                # TODO: handle ShardedTensor/PartialTensor inter-op with ReplicatedTensor
+                return arg.__torch_function__(func, types, args, kwargs)
+            if isinstance(arg, ReplicatedTensor):
+                if replicated_pg is None:
+                    replicated_pg = arg.process_group
+                elif replicated_pg != arg.process_group:
+                    raise RuntimeError(
+                        f"ReplicatedTensor operands must be in the same process group "
+                        f"in torch function '{func.__name__}', but found at least two "
+                        f"ReplicatedTensor operands in different process groups! ")
+            else:
+                all_replicated = False
+
+        for arg in args:
+            dispatch_arg(arg)
+
+        if kwargs is not None:
+            for k, v in kwargs.items():
+                dispatch_arg(v)
+
+        # We can't call super().__torch_function__() as it implicitly converts the result
+        # back to the tensor subclass, whereas in our case we need to control the output
+        # type based on the inter-op rules we defined.
+        with torch._C.DisableTorchFunction():
+            rs = func(*args, **kwargs)
+            if func in get_default_nowrap_functions():
+                return rs
+            if all_replicated and isinstance(rs, torch.Tensor) and not isinstance(rs, cls):
+                # if all operands are ReplicatedTensors, the call did not get dispatched to
+                # ShardedTensor's __torch_function__, and the result is a torch.Tensor, then
+                # we convert and return a ReplicatedTensor according to our inter-op rules
+                rs = rs.as_subclass(cls)        # type: ignore[arg-type]
+                # propagate the process_group field to result
+                rs.process_group = replicated_pg        # type: ignore[attr-defined]
+
+            return rs
+
+    def validate(self) -> bool:
+        """
+        Validate the ReplicatedTensor is legit by all gathering tensors on all ranks
+        and check to make sure they are the same.
+
+        If there's some ranks with different values, a ValueError will be raised.
+
+        Keyword args:
+            process_group (ProcessGroup, optional): The process group to work on. If None,
+                the default process group will be used.
+
+        Returns:
+            True if validation succeed.
+        """
+        world_size = dist.get_world_size(self.process_group)
+        current_rank = dist.get_rank(self.process_group)
+
+        tensors_on_rank = [torch.empty_like(self) for _ in range(world_size)]
+
+        dist.all_gather(tensors_on_rank, self, group=self.process_group)
+        # validate and check if all tensors are equal
+        for rank, tensor in enumerate(tensors_on_rank):
+            if not torch.allclose(self, tensor):
+                raise ValueError(
+                    f"ReplicatedTensor have different values on rank {current_rank} and {rank}")
+
+        return True
diff --git a/torch/distributed/_shard/sharded_tensor/__init__.py b/torch/distributed/_shard/sharded_tensor/__init__.py
index ba1cdf3..58b5e02 100644
--- a/torch/distributed/_shard/sharded_tensor/__init__.py
+++ b/torch/distributed/_shard/sharded_tensor/__init__.py
@@ -366,7 +366,7 @@
     parameters, the function provided will be invoked for that operator.
 
     Example::
-        >>> @custom_sharded_op(torch.nn.functional.linear)
+        >>> @sharded_op_impl(torch.nn.functional.linear)
         >>> def my_custom_sharded_linear(types, args, kwargs, process_group):
         >>>   ....
         >>>