Make PartialTensor a torch.Tensor subclass
Pull Request resolved: https://github.com/pytorch/pytorch/pull/77626
Differential Revision: [D36435152](https://our.internmc.facebook.com/intern/diff/D36435152/)
Approved by: https://github.com/wanchaol
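
For reviewers unfamiliar with tensor wrapper subclasses: this change moves `_PartialTensor` onto `torch.Tensor._make_wrapper_subclass`, so the object carries real tensor metadata (size/dtype/layout) while op handling is routed through `__torch_function__`. Below is a minimal, self-contained sketch of that pattern; the names `_Wrapper`, `_CUSTOM_OPS`, and `_size_impl` are illustrative stand-ins, not the actual PartialTensor code in this diff.

import torch
from typing import Callable, Dict

_CUSTOM_OPS: Dict[Callable, Callable] = {}

class _Wrapper(torch.Tensor):
    __slots__ = ["_local"]

    def __new__(cls, local: torch.Tensor):
        # Create a Tensor "shell" that carries the shard's metadata
        # (size/dtype/layout) but owns no storage of its own.
        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
            cls,
            local.size(),
            dtype=local.dtype,
            layout=local.layout,
            requires_grad=local.requires_grad,
        )
        r._local = local
        return r

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # Only explicitly registered ops are supported; everything else errors.
        if func in _CUSTOM_OPS:
            return _CUSTOM_OPS[func](types, args, kwargs)
        raise RuntimeError(f"{func.__name__} not supported for _Wrapper")

def _size_impl(types, args=(), kwargs=None):
    # Answer size() from the wrapper's own metadata, bypassing dispatch.
    with torch._C.DisableTorchFunction():
        return torch.Tensor.size(*args, **(kwargs or {}))

_CUSTOM_OPS[torch.Tensor.size] = _size_impl

w = _Wrapper(torch.ones(2, 3))
print(w.size())  # torch.Size([2, 3]), routed through __torch_function__
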
diff --git a/torch/distributed/_shard/partial_tensor.py b/torch/distributed/_shard/partial_tensor.py
index 56d383a..0416075 100644
--- a/torch/distributed/_shard/partial_tensor.py
+++ b/torch/distributed/_shard/partial_tensor.py
@@ -9,7 +9,6 @@
from torch.distributed.nn.functional import (
reduce_scatter,
)
-from torch.overrides import handle_torch_function
# Custom PartialTensor ops
_PARTIAL_TENSOR_OPS: Dict[Callable, Callable] = {}
@@ -40,7 +39,7 @@
return wrapper
return decorator_sharded_func
-class _PartialTensor(object):
+class _PartialTensor(torch.Tensor):
"""
PartialTensor is an abstraction to represent Tensors that need
aggregation across multiple devices and multiple processes.
@@ -116,21 +115,31 @@
tensor([8, 10]) # Rank 1
"""
- def __init__(
- self, local_shard, process_group=None, reduce_op=distributed_c10d.ReduceOp.SUM
- ):
- self.local_shard = local_shard
- self._process_group = (
+ _process_group: distributed_c10d.ProcessGroup
+ _local_shard: torch.Tensor
+ _reduce_op: distributed_c10d.ReduceOp
+
+ __slots__ = ["_process_group", "_local_shard", "_reduce_op"]
+
+ def __new__(cls, local_shard, process_group=None, reduce_op=distributed_c10d.ReduceOp.SUM):
+ r = torch.Tensor._make_wrapper_subclass( # type: ignore[attr-defined]
+ cls,
+ local_shard.size(),
+ dtype=local_shard.dtype,
+ layout=local_shard.layout,
+ pin_memory=local_shard.is_pinned(),
+ requires_grad=local_shard.requires_grad) # type: ignore[arg-type]
+ r._process_group = ( # type: ignore[attr-defined]
process_group
- if process_group
- else dist.distributed_c10d._get_default_group()
+ if process_group is not None
+ else distributed_c10d._get_default_group()
)
- self.reduce_op = reduce_op
+ r._reduce_op = reduce_op
+ r._local_shard = local_shard
+ return r
def __post_init__(self):
- if not isinstance(self.local_shard, torch.Tensor):
- raise ValueError("local_shard needs to be a Tensor.")
- if not isinstance(self.reduce_op, distributed_c10d.ReduceOp):
+ if not isinstance(self._reduce_op, distributed_c10d.ReduceOp):
raise ValueError(
"reduce_op needs to be a member of distributed_c10d.ReduceOp."
)
@@ -154,17 +163,17 @@
"""
if not isinstance(resharding_spec, shard_spec.ChunkShardingSpec):
raise NotImplementedError("Only ChunkShardingSpec supported for reshard.")
- if self.local_shard.is_complex():
+ if self._local_shard.is_complex():
raise NotImplementedError("Only real partial tensor supported for reshard.")
sharding_dim = int(resharding_spec.dim) # type: ignore[attr-defined]
- chunk_mode_res = self.local_shard.size(sharding_dim) % self._process_group.size()
- local_shard = self.local_shard
+ chunk_mode_res = self._local_shard.size(sharding_dim) % self._process_group.size()
+ local_shard = self._local_shard
# Add padding when the size is not divisible by the world size.
if chunk_mode_res != 0:
- padding = [0] * (self.local_shard.dim() * 2)
+ padding = [0] * (local_shard.dim() * 2)
padding[-1] = self._process_group.size() - chunk_mode_res
local_shard = torch.nn.functional.pad(
- self.local_shard,
+ local_shard,
tuple(padding),
"constant",
0,
@@ -185,16 +194,16 @@
# Need to re-arrange original shard_dim of output_tensor_list.
local_shards = [local_shards[idx] for idx in indices] # type: ignore[call-overload]
local_result = reduce_scatter(
- torch.empty_like(local_shards[0]), list(local_shards), op=self.reduce_op
+ torch.empty_like(local_shards[0]), list(local_shards), op=self._reduce_op
)
- sharded_tensor_size = self.local_shard.size()
+ sharded_tensor_size = self._local_shard.size()
# Remove padding when the size is not divisible by the world size.
if chunk_mode_res != 0:
- uneven_local_shards = self.local_shard.chunk(
+ uneven_local_shards = self._local_shard.chunk(
self._process_group.size(), dim=sharding_dim
)
- expected_size = uneven_local_shards[rank_idx].size()
+ expected_size = uneven_local_shards[rank_idx].size() # type: ignore[index]
if local_result.size() != expected_size:
local_result = local_result.narrow(
sharding_dim,
@@ -208,29 +217,41 @@
process_group=self._process_group,
)
- def size(self):
- return self.local_shard.size()
-
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
if func in _PARTIAL_TENSOR_OPS:
return _PARTIAL_TENSOR_OPS[func](types, args, kwargs)
- raise RuntimeError(
- f"torch function '{func.__name__}', with args: {args} and "
- f"kwargs: {kwargs} not supported for PartialTensor!")
+ # Need to disable all dispatch to print args and kwargs appropriately.
+ guard = torch._C._DisableTorchDispatch() # type: ignore[attr-defined]
+ try:
+ with torch._C.DisableTorchFunction():
+ raise RuntimeError(
+ f"torch function '{func.__name__}', with args: {args} and "
+ f"kwargs: {kwargs} not supported for PartialTensor!")
+ finally:
+ del guard
- def transpose(self, dim0, dim1):
- return handle_torch_function(torch.Tensor.transpose, (self, dim0, dim1), self, dim0, dim1)
+ @classmethod
+ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+ raise RuntimeError(
+ f"A {cls.__name__} object is being used from c++ "
+ f"while calling {func.__module__}.{func.__name__} "
+ "but the there is no custom __torch_dispatch__ implementation for it."
+ )
+
+ def __repr__(self):
+ return f"PartialTensor({super(_PartialTensor, self).__repr__()})"
def _transpose_impl(types, args=(), kwargs=None):
- input = args[0]
+ partial_tensor = args[0]
+ input = partial_tensor._local_shard
dim0 = args[1]
dim1 = args[2]
return _PartialTensor(
- torch.transpose(input.local_shard, dim0, dim1),
- input._process_group,
- input.reduce_op
+ torch.transpose(input, dim0, dim1),
+ partial_tensor._process_group,
+ partial_tensor._reduce_op
)
@_custom_partial_tensor_op(torch.Tensor.transpose)
@@ -252,11 +273,14 @@
if not isinstance(input, _PartialTensor):
raise RuntimeError('All inputs need to be an instance of _PartialTensor')
if idx == 0:
- reduce_op = input.reduce_op
- elif reduce_op != input.reduce_op:
- raise RuntimeError('All _PartialTensor reduce_ops need to be the same, found: {reduce_op} and {input.reduce_op}')
+ reduce_op = input._reduce_op
+ elif reduce_op != input._reduce_op:
+ raise RuntimeError(
+ 'All _PartialTensor reduce_ops need to be the same, found: '
+ f'{reduce_op} and {input._reduce_op}'
+ )
- local_shards.append(input.local_shard)
+ local_shards.append(input._local_shard)
if kwargs is None:
dim = 0
@@ -264,4 +288,12 @@
if 'out' in kwargs:
raise RuntimeError('"out" kwarg is not supported!')
dim = kwargs['dim'] if 'dim' in kwargs else 0
- return _PartialTensor(torch.cat(local_shards, dim), input._process_group, input.reduce_op)
+
+ return _PartialTensor(torch.cat(local_shards, dim), input._process_group, input._reduce_op)
+
+@_custom_partial_tensor_op(torch.Tensor.size)
+def partial_size(types, args=(), kwargs=None):
+ if kwargs is None:
+ kwargs = {}
+ with torch._C.DisableTorchFunction():
+ return torch.Tensor.size(*args, **kwargs)
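
As a usage illustration of the new dispatch path (not part of the diff): a single-rank process group is assumed so the default group exists, and the gloo backend, address, and port below are placeholders for a local run.

import torch
import torch.distributed as dist
from torch.distributed._shard.partial_tensor import _PartialTensor

# Single-rank process group so _PartialTensor can fall back to the default group.
dist.init_process_group(
    backend="gloo", init_method="tcp://127.0.0.1:29500", rank=0, world_size=1
)

p = _PartialTensor(torch.arange(6).reshape(2, 3))
print(p.size())        # torch.Size([2, 3]), handled by the new partial_size op
q = p.transpose(0, 1)  # handled by _transpose_impl, returns a _PartialTensor
print(q.size())        # torch.Size([3, 2])

dist.destroy_process_group()
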