Make PartialTensor a torch.Tensor subclass

Pull Request resolved: https://github.com/pytorch/pytorch/pull/77626

Differential Revision: [D36435152](https://our.internmc.facebook.com/intern/diff/D36435152/)

Approved by: https://github.com/wanchaol
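
For context, the core of the change is the wrapper-subclass pattern: `_PartialTensor` no longer stores data itself but wraps a local shard, mirrors its metadata via `torch.Tensor._make_wrapper_subclass`, and routes the handful of supported ops through `__torch_function__`. A minimal sketch of that pattern follows (the `_Wrapper` and `_inner` names are illustrative, not part of this PR):

```python
import torch

class _Wrapper(torch.Tensor):
    # No storage of its own: the real data lives in the wrapped tensor,
    # while size/dtype/layout metadata is mirrored onto the wrapper.
    __slots__ = ["_inner"]

    def __new__(cls, inner: torch.Tensor):
        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
            cls,
            inner.size(),
            dtype=inner.dtype,
            layout=inner.layout,
            requires_grad=inner.requires_grad,
        )
        r._inner = inner
        return r

    @classmethod
    def __torch_function__(cls, func, types, args=(), kwargs=None):
        # A real implementation consults a registry of supported ops
        # (cf. _PARTIAL_TENSOR_OPS in the diff); everything else fails fast.
        raise RuntimeError(f"torch function '{func.__name__}' not supported")

w = _Wrapper(torch.ones(2, 3))
assert isinstance(w, torch.Tensor)  # the point of the change
```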
diff --git a/torch/distributed/_shard/partial_tensor.py b/torch/distributed/_shard/partial_tensor.py
index 56d383a..0416075 100644
--- a/torch/distributed/_shard/partial_tensor.py
+++ b/torch/distributed/_shard/partial_tensor.py
@@ -9,7 +9,6 @@
 from torch.distributed.nn.functional import (
     reduce_scatter,
 )
-from torch.overrides import handle_torch_function
 
 # Custom PartialTensor ops
 _PARTIAL_TENSOR_OPS: Dict[Callable, Callable] = {}
@@ -40,7 +39,7 @@
         return wrapper
     return decorator_sharded_func
 
-class _PartialTensor(object):
+class _PartialTensor(torch.Tensor):
     """
     PartialTensor is an abstraction to represent Tensors that need
     aggregation across multiple devices and multiple processes.
@@ -116,21 +115,31 @@
         tensor([8, 10]) # Rank 1
     """
 
-    def __init__(
-        self, local_shard, process_group=None, reduce_op=distributed_c10d.ReduceOp.SUM
-    ):
-        self.local_shard = local_shard
-        self._process_group = (
+    _process_group: distributed_c10d.ProcessGroup
+    _local_shard: torch.Tensor
+    _reduce_op: distributed_c10d.ReduceOp
+
+    __slots__ = ["_process_group", "_local_shard", "_reduce_op"]
+
+    def __new__(cls, local_shard, process_group=None, reduce_op=distributed_c10d.ReduceOp.SUM):
+        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
+            cls,
+            local_shard.size(),
+            dtype=local_shard.dtype,
+            layout=local_shard.layout,
+            pin_memory=local_shard.is_pinned(),
+            requires_grad=local_shard.requires_grad)      # type: ignore[arg-type]
+        r._process_group = (     # type: ignore[attr-defined]
             process_group
-            if process_group
-            else dist.distributed_c10d._get_default_group()
+            if process_group is not None
+            else distributed_c10d._get_default_group()
         )
-        self.reduce_op = reduce_op
+        r._reduce_op = reduce_op
+        r._local_shard = local_shard
+        return r
 
     def __post_init__(self):
-        if not isinstance(self.local_shard, torch.Tensor):
-            raise ValueError("local_shard needs to be a Tensor.")
-        if not isinstance(self.reduce_op, distributed_c10d.ReduceOp):
+        if not isinstance(self._reduce_op, distributed_c10d.ReduceOp):
             raise ValueError(
                 "reduce_op needs to be a member of distributed_c10d.ReduceOp."
             )
@@ -154,17 +163,17 @@
         """
         if not isinstance(resharding_spec, shard_spec.ChunkShardingSpec):
             raise NotImplementedError("Only ChunkShardingSpec supported for reshard.")
-        if self.local_shard.is_complex():
+        if self._local_shard.is_complex():
             raise NotImplementedError("Only real partial tensor supported for reshard.")
         sharding_dim = int(resharding_spec.dim)  # type: ignore[attr-defined]
-        chunk_mode_res = self.local_shard.size(sharding_dim) % self._process_group.size()
-        local_shard = self.local_shard
+        chunk_mode_res = self._local_shard.size(sharding_dim) % self._process_group.size()
+        local_shard = self._local_shard
         # Add padding when the size is not divisible by the world size.
         if chunk_mode_res != 0:
-            padding = [0] * (self.local_shard.dim() * 2)
+            padding = [0] * (local_shard.dim() * 2)
             padding[-1] = self._process_group.size() - chunk_mode_res
             local_shard = torch.nn.functional.pad(
-                self.local_shard,
+                local_shard,
                 tuple(padding),
                 "constant",
                 0,
@@ -185,16 +194,16 @@
             # Need to re-arrange original shard_dim of output_tensor_list.
             local_shards = [local_shards[idx] for idx in indices]  # type: ignore[call-overload]
         local_result = reduce_scatter(
-            torch.empty_like(local_shards[0]), list(local_shards), op=self.reduce_op
+            torch.empty_like(local_shards[0]), list(local_shards), op=self._reduce_op
         )
 
-        sharded_tensor_size = self.local_shard.size()
+        sharded_tensor_size = self._local_shard.size()
         # Remove padding when the size is not divisible by the world size.
         if chunk_mode_res != 0:
-            uneven_local_shards = self.local_shard.chunk(
+            uneven_local_shards = self._local_shard.chunk(
                 self._process_group.size(), dim=sharding_dim
             )
-            expected_size = uneven_local_shards[rank_idx].size()
+            expected_size = uneven_local_shards[rank_idx].size()  # type: ignore[index]
             if local_result.size() != expected_size:
                 local_result = local_result.narrow(
                     sharding_dim,
@@ -208,29 +217,41 @@
             process_group=self._process_group,
         )
 
-    def size(self):
-        return self.local_shard.size()
-
     @classmethod
     def __torch_function__(cls, func, types, args=(), kwargs=None):
         if func in _PARTIAL_TENSOR_OPS:
             return _PARTIAL_TENSOR_OPS[func](types, args, kwargs)
 
-        raise RuntimeError(
-            f"torch function '{func.__name__}', with args: {args} and "
-            f"kwargs: {kwargs} not supported for PartialTensor!")
+        # Need to disable all dispatch to print args and kwargs appropriately.
+        guard = torch._C._DisableTorchDispatch()  # type: ignore[attr-defined]
+        try:
+            with torch._C.DisableTorchFunction():
+                raise RuntimeError(
+                    f"torch function '{func.__name__}', with args: {args} and "
+                    f"kwargs: {kwargs} not supported for PartialTensor!")
+        finally:
+            del guard
 
-    def transpose(self, dim0, dim1):
-        return handle_torch_function(torch.Tensor.transpose, (self, dim0, dim1), self, dim0, dim1)
+    @classmethod
+    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+        raise RuntimeError(
+            f"A {cls.__name__} object is being used from c++ "
+            f"while calling {func.__module__}.{func.__name__} "
+            "but the there is no custom __torch_dispatch__ implementation for it."
+        )
+
+    def __repr__(self):
+        return f"PartialTensor({super(_PartialTensor, self).__repr__()})"
 
 def _transpose_impl(types, args=(), kwargs=None):
-    input = args[0]
+    partial_tensor = args[0]
+    input = partial_tensor._local_shard
     dim0 = args[1]
     dim1 = args[2]
     return _PartialTensor(
-        torch.transpose(input.local_shard, dim0, dim1),
-        input._process_group,
-        input.reduce_op
+        torch.transpose(input, dim0, dim1),
+        partial_tensor._process_group,
+        partial_tensor._reduce_op
     )
 
 @_custom_partial_tensor_op(torch.Tensor.transpose)
@@ -252,11 +273,14 @@
         if not isinstance(input, _PartialTensor):
             raise RuntimeError('All inputs need to be an instance of _PartialTensor')
         if idx == 0:
-            reduce_op = input.reduce_op
-        elif reduce_op != input.reduce_op:
-            raise RuntimeError('All _PartialTensor reduce_ops need to be the same, found: {reduce_op} and {input.reduce_op}')
+            reduce_op = input._reduce_op
+        elif reduce_op != input._reduce_op:
+            raise RuntimeError(
+                'All _PartialTensor reduce_ops need to be the same, found: '
+                f'{reduce_op} and {input._reduce_op}'
+            )
 
-        local_shards.append(input.local_shard)
+        local_shards.append(input._local_shard)
 
     if kwargs is None:
         dim = 0
@@ -264,4 +288,12 @@
         if 'out' in kwargs:
             raise RuntimeError('"out" kwarg is not supported!')
         dim = kwargs['dim'] if 'dim' in kwargs else 0
-    return _PartialTensor(torch.cat(local_shards, dim), input._process_group, input.reduce_op)
+
+    return _PartialTensor(torch.cat(local_shards, dim), input._process_group, input._reduce_op)
+
+@_custom_partial_tensor_op(torch.Tensor.size)
+def partial_size(types, args=(), kwargs=None):
+    if kwargs is None:
+        kwargs = {}
+    with torch._C.DisableTorchFunction():
+        return torch.Tensor.size(*args, **kwargs)
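
As a usage note, here is a rough sketch of how the new `size` override behaves once a default process group exists; the single-rank gloo setup and env values are purely illustrative:

```python
import os
import torch
import torch.distributed as dist
from torch.distributed._shard.partial_tensor import _PartialTensor

# Single-process world, just so _get_default_group() succeeds.
os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
os.environ.setdefault("MASTER_PORT", "29500")
dist.init_process_group("gloo", rank=0, world_size=1)

pt = _PartialTensor(torch.ones(4, 2))
# .size() dispatches via __torch_function__ into partial_size, which
# disables torch function to read the wrapper's mirrored shape.
assert pt.size() == torch.Size([4, 2])

# Ops without a registered handler hit the error branch, e.g.:
# torch.add(pt, pt)  -> RuntimeError: torch function 'add' ... not supported

dist.destroy_process_group()
```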