[shard] Add ReplicatedTensor (#73529)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73529
Add ReplicatedTensor. A ReplicatedTensor is a type of tensor that has the same value on all ranks across the world_size.
ReplicatedTensor is a :class:`~torch.Tensor` subclass, and it can be used together with ShardedTensor/Tensor to express different types of computation. The inter-op rules are defined as follows (using torch.add as an example op):
ReplicatedTensor + ReplicatedTensor = ReplicatedTensor
ReplicatedTensor + torch.Tensor = torch.Tensor
ReplicatedTensor + ShardedTensor = ShardedTensor
We also added a `validate()` API to help users check whether a replicated tensor on a given process_group is truly replicated.
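A minimal usage sketch of the rules above and of `validate()` (a sketch only; it assumes torch.distributed is already initialized with a default process group, and the values are illustrative):

    # every rank constructs the same local value, so marking it replicated is valid
    local_tensor = torch.ones(3, 3) * 4
    replica = ReplicatedTensor(local_tensor)
    replica.validate()                        # True when values match on every rank

    replica + ReplicatedTensor(local_tensor)  # -> ReplicatedTensor
    replica + torch.randn(3, 3)               # -> plain torch.Tensor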
TODO: the next PR will add ShardedTensor/PartialTensor logic to handle ReplicatedTensor.
ghstack-source-id: 152064781
Test Plan: test_replicated_tensor
Reviewed By: pritamdamania87, fduwjj
Differential Revision: D34529374
fbshipit-source-id: 16ccb300e9f9c47ac29a17eb6d46d029ab7d60b8
(cherry picked from commit 44f4e11e795a1bf330a8108bda256950ca769525)
diff --git a/test/distributed/_shard/test_replicated_tensor.py b/test/distributed/_shard/test_replicated_tensor.py
new file mode 100644
index 0000000..474fbfb
--- /dev/null
+++ b/test/distributed/_shard/test_replicated_tensor.py
@@ -0,0 +1,76 @@
+# Owner(s): ["oncall: distributed"]
+
+import torch
+
+import torch.distributed as dist
+
+from torch.testing._internal.common_distributed import (
+ requires_nccl,
+ skip_if_lt_x_gpu,
+)
+
+from torch.testing._internal.distributed._shard.sharded_tensor import (
+ ShardedTensorTestBase,
+ with_comms,
+)
+from torch.distributed._shard.replicated_tensor import ReplicatedTensor
+
+
+class TestReplicatedTensor(ShardedTensorTestBase):
+
+ @with_comms(init_rpc=False)
+ @skip_if_lt_x_gpu(4)
+ @requires_nccl()
+ def test_replicated_tensor_basics(self):
+ local_tensor = torch.ones(3, 3, device=f"cuda:{self.rank}") * 4
+ replica_tensor = ReplicatedTensor(local_tensor)
+ print(replica_tensor.process_group)
+ # validate it's a replicated tensor by checking values on all ranks
+ validated = replica_tensor.validate()
+ self.assertEqual(validated, True)
+ res = replica_tensor + 2
+ self.assertIsInstance(res, torch.Tensor)
+ self.assertNotIsInstance(res, ReplicatedTensor)
+ self.assertEqual(res, torch.ones(3, 3) * 6)
+
+ # modify the local tensor on one rank, and check that validation raises
+ if self.rank == 2:
+ local_tensor += 3
+
+ with self.assertRaisesRegex(ValueError, 'have different values'):
+ replica_tensor.validate()
+
+ @with_comms(init_rpc=False)
+ @skip_if_lt_x_gpu(4)
+ @requires_nccl()
+ def test_replicated_tensor_inter_op_replicated_tensor(self):
+ local_tensor = torch.ones(3, 3, device=f"cuda:{self.rank}")
+ replica_tensor1 = ReplicatedTensor(local_tensor * 4)
+ replica_tensor2 = ReplicatedTensor(local_tensor * 6)
+
+ new_tensor = replica_tensor1 * replica_tensor2
+ self.assertIsInstance(new_tensor, ReplicatedTensor)
+ self.assertEqual(new_tensor, torch.ones(3, 3) * 24)
+
+ # test replicated tensor inter-op with different pgs
+ new_pg = dist.new_group(ranks=[1, 2, 3])
+ replica_tensor_new_group = ReplicatedTensor(local_tensor * 3, process_group=new_pg)
+
+ with self.assertRaisesRegex(RuntimeError, 'must be in the same'):
+ replica_tensor_new_group * replica_tensor1
+
+
+ @with_comms(init_rpc=False)
+ @skip_if_lt_x_gpu(4)
+ @requires_nccl()
+ def test_replicated_tensor_inter_op_tensor(self):
+ local_tensor = torch.ones(3, 3, device=f"cuda:{self.rank}") * 4
+ replica_tensor = ReplicatedTensor(local_tensor)
+
+ local_rand_tensor = torch.randn(3, 3, device=f"cuda:{self.rank}")
+
+ new_tensor = replica_tensor + local_rand_tensor
+ self.assertIsInstance(new_tensor, torch.Tensor)
+ self.assertNotIsInstance(new_tensor, ReplicatedTensor)
+
+ self.assertEqual(new_tensor, local_tensor + local_rand_tensor)
diff --git a/test/run_test.py b/test/run_test.py
index 176b35b..d378f15 100644
--- a/test/run_test.py
+++ b/test/run_test.py
@@ -214,6 +214,7 @@
"distributed/_shard/sharded_tensor/ops/test_linear",
"distributed/_shard/sharding_spec/test_sharding_spec",
"distributed/_shard/sharded_optim/test_sharded_optim",
+ "distributed/_shard/test_replicated_tensor",
] + FSDP_TEST
ROCM_BLOCKLIST = [
@@ -233,6 +234,7 @@
"distributed/_shard/sharded_tensor/ops/test_linear",
"distributed/_shard/sharding_spec/test_sharding_spec",
"distributed/_shard/sharded_optim/test_sharded_optim",
+ "distributed/_shard/test_replicated_tensor",
"test_determination",
"test_jit_legacy",
"test_type_hints",
diff --git a/torch/distributed/_shard/__init__.py b/torch/distributed/_shard/__init__.py
index b6f0776..194ae2c 100644
--- a/torch/distributed/_shard/__init__.py
+++ b/torch/distributed/_shard/__init__.py
@@ -1 +1 @@
-from .api import shard_parameter, _shard_tensor
+from .api import shard_parameter, _shard_tensor, _replicate_tensor
diff --git a/torch/distributed/_shard/api.py b/torch/distributed/_shard/api.py
index 0de8a59..c5082b0 100644
--- a/torch/distributed/_shard/api.py
+++ b/torch/distributed/_shard/api.py
@@ -7,6 +7,7 @@
from .sharding_spec import (
ShardingSpec,
)
+from .replicated_tensor import ReplicatedTensor
def _shard_tensor(
tensor: torch.Tensor, sharding_spec: ShardingSpec, src_rank=0, process_group=None
@@ -118,3 +119,20 @@
# Now we can set the attribute appropriately.
setattr(module, param_name, st)
+
+
+def _replicate_tensor(tensor: torch.Tensor, process_group=None) -> ReplicatedTensor:
+ """
+ Given a :class:`torch.Tensor`, mark it as a ReplicatedTensor where all
+ ranks have the same value.
+
+ Args:
+ tensor (:class:`torch.Tensor`): the tensor to be marked as replicated.
+ Keyword args:
+ process_group (ProcessGroup, optional): The process group to replicate on.
+ If None, the default process group will be used.
+ Returns:
+ A :class:`ReplicatedTensor` from the given tensor.
+
+ """
+ return ReplicatedTensor(tensor, process_group=process_group)
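A rough usage sketch for the helper above (a sketch only; it assumes torch.distributed is already initialized with a default process group):

    import torch
    from torch.distributed._shard import _replicate_tensor

    # every rank passes the same local value; the helper wraps it in a ReplicatedTensor
    # bound to the default process group (pass process_group=... to use another group)
    replica = _replicate_tensor(torch.zeros(4, 4))
    assert replica.validate()  # raises ValueError if ranks disagree, otherwise True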
diff --git a/torch/distributed/_shard/replicated_tensor.py b/torch/distributed/_shard/replicated_tensor.py
new file mode 100644
index 0000000..12253a0
--- /dev/null
+++ b/torch/distributed/_shard/replicated_tensor.py
@@ -0,0 +1,125 @@
+import torch
+import torch.distributed as dist
+
+from torch.overrides import get_default_nowrap_functions
+from torch.distributed._shard.sharded_tensor.api import ShardedTensor
+from torch.distributed import distributed_c10d
+
+
+class ReplicatedTensor(torch.Tensor):
+ """
+ ReplicatedTensor represents a tensor which is replicated across the `world_size` and
+ has the same value on each rank.
+
+ ReplicatedTensor is a :class:`~torch.Tensor` subclass, and it can be used together
+ with ShardedTensor/Tensor to express different types of computation. The
+ inter-op rules are defined as follows (using torch.add as an example op):
+ ReplicatedTensor + ReplicatedTensor = ReplicatedTensor
+ ReplicatedTensor + torch.Tensor = torch.Tensor
+ ReplicatedTensor + ShardedTensor = ShardedTensor
+ ReplicatedTensor + other type (e.g. a scalar) = other type
+
+ NOTE: We do not guarantee equal content of ReplicatedTensor across nodes after its
+ construction. Although we define proper inter-op rules to make sure a ReplicatedTensor
+ stays the same, there is no enforcement of it (i.e. if you manually modify content on
+ some ranks, the modified value will not automatically get synced to other nodes). If
+ you wish to manually validate that tensors are the same across ranks, use `validate()`.
+
+ """
+ process_group: distributed_c10d.ProcessGroup
+
+ __slots__ = ["process_group"]
+
+ def __new__(cls, data=None, process_group=None):
+ if data is None:
+ data = torch.empty(0)
+ r = torch.Tensor._make_subclass(cls, data) # type: ignore[arg-type]
+ r.process_group = ( # type: ignore[attr-defined]
+ process_group
+ if process_group is not None
+ else distributed_c10d._get_default_group()
+ )
+ return r
+
+ def __repr__(self):
+ return f"ReplicatedTensor({super(ReplicatedTensor, self).__repr__()})"
+
+ @classmethod
+ def __torch_function__(cls, func, types, args=(), kwargs=None):
+ if kwargs is None:
+ kwargs = {}
+ # We will re-dispatch the execution to ShardedTensor's __torch_function__
+ # if we find ShardedTensor operands. We also check whether all tensor operands in
+ # args/kwargs are ReplicatedTensors; we have to do this to ensure we do not
+ # convert results back to ReplicatedTensor if not all operands are replicated.
+ all_replicated = True
+ replicated_pg = None
+
+ def dispatch_arg(arg):
+ nonlocal replicated_pg, all_replicated
+ if isinstance(arg, ShardedTensor):
+ # redispatch to ShardedTensor
+ # TODO: handle ShardedTensor/PartialTensor inter-op with ReplicatedTensor
+ return arg.__torch_function__(func, types, args, kwargs)
+ if isinstance(arg, ReplicatedTensor):
+ if replicated_pg is None:
+ replicated_pg = arg.process_group
+ elif replicated_pg != arg.process_group:
+ raise RuntimeError(
+ f"ReplicatedTensor operands must be in the same process group "
+ f"in torch function '{func.__name__}', but found at least two "
+ f"ReplicatedTensor operands in different process groups! ")
+ else:
+ all_replicated = False
+
+ for arg in args:
+ dispatch_arg(arg)
+
+ if kwargs is not None:
+ for k, v in kwargs.items():
+ dispatch_arg(v)
+
+ # We can't use super().__torch_function__() as it implicitly converts the result
+ # back to the tensor subclass, whereas in our case we need to control the output
+ # type based on the inter-op rules we defined.
+ with torch._C.DisableTorchFunction():
+ rs = func(*args, **kwargs)
+ if func in get_default_nowrap_functions():
+ return rs
+ if all_replicated and isinstance(rs, torch.Tensor) and not isinstance(rs, cls):
+ # if all operands are ReplicatedTensors, the call was not dispatched to
+ # ShardedTensor's __torch_function__, and the result is a torch.Tensor, then
+ # convert it and return a ReplicatedTensor according to our inter-op rules
+ rs = rs.as_subclass(cls) # type: ignore[arg-type]
+ # propagate the process_group field to result
+ rs.process_group = replicated_pg # type: ignore[attr-defined]
+
+ return rs
+
+ def validate(self) -> bool:
+ """
+ Validate that the ReplicatedTensor is legitimate by all-gathering tensors from
+ all ranks and checking that they are the same.
+
+ If some ranks have different values, a ValueError will be raised.
+
+ Returns:
+ True if validation succeeds.
+ """
+ world_size = dist.get_world_size(self.process_group)
+ current_rank = dist.get_rank(self.process_group)
+
+ tensors_on_rank = [torch.empty_like(self) for _ in range(world_size)]
+
+ dist.all_gather(tensors_on_rank, self, group=self.process_group)
+ # validate and check if all tensors are equal
+ for rank, tensor in enumerate(tensors_on_rank):
+ if not torch.allclose(self, tensor):
+ raise ValueError(
+ f"ReplicatedTensor have different values on rank {current_rank} and {rank}")
+
+ return True
diff --git a/torch/distributed/_shard/sharded_tensor/__init__.py b/torch/distributed/_shard/sharded_tensor/__init__.py
index ba1cdf3..58b5e02 100644
--- a/torch/distributed/_shard/sharded_tensor/__init__.py
+++ b/torch/distributed/_shard/sharded_tensor/__init__.py
@@ -366,7 +366,7 @@
parameters, the function provided will be invoked for that operator.
Example::
- >>> @custom_sharded_op(torch.nn.functional.linear)
+ >>> @sharded_op_impl(torch.nn.functional.linear)
>>> def my_custom_sharded_linear(types, args, kwargs, process_group):
>>> ....
>>>