[torchrec][pt-d][model store] introduce LocalShardsWrapper for DTensor (#129150)

Summary:
Same as D57688538, recreated because of GH issues

This diff introduces LocalShardsWrapper, which is crucial to migrating the TRec state dict representation from ShardedTensor to DTensor, along with the changes needed in PT-D and ModelStore to support this.

It allows us to extend DTensor to support multiple shards on a rank, as well as empty shards on a rank, as required by TRec sharding logic.
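
As a rough construction sketch (illustrative only; the shapes, offsets, and variable names below are not from this diff), a rank holding two column-wise shards of a table could wrap them as follows:

```python
import torch

from torch.distributed._tensor._shards_wrapper import LocalShardsWrapper

# Two column-wise shards owned by this rank, plus their global (row, col) offsets.
shard_a = torch.randn(16, 64)
shard_b = torch.randn(16, 32)
lsw = LocalShardsWrapper([shard_a, shard_b], [(0, 0), (0, 64)])

print(lsw.shape)            # torch.Size([16, 96]) -- shards "concatenated" along dim 1
print(lsw.local_sizes())    # [torch.Size([16, 64]), torch.Size([16, 32])]
print(lsw.local_offsets())  # [torch.Size([0, 0]), torch.Size([0, 64])]

# The wrapper is itself a tensor subclass, so it can serve as the local tensor
# of a DTensor (e.g. via DTensor.from_local) once a device mesh is initialized.
```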

This diff also extends support for using LocalShardsWrapper in conjunction with DTensor in checkpointing flows (ModelStore and DCP).
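
As a hedged checkpointing sketch (assumes torch.distributed is already initialized, e.g. via torchrun, and reuses `lsw` from the sketch above; the FQN, placement, and checkpoint path are made up for illustration):

```python
import torch.distributed as dist
import torch.distributed.checkpoint as dcp

from torch.distributed._tensor import DTensor, Shard
from torch.distributed.device_mesh import init_device_mesh

mesh = init_device_mesh("cpu", (dist.get_world_size(),))
dt = DTensor.from_local(lsw, mesh, [Shard(1)], run_check=False)

state_dict = {"sparse.embedding_bags.table_0.weight": dt}

# DCP discovers the per-rank shards because DTensor forwards
# __create_write_items__ / __create_chunk_list__ / __get_tensor_shard__
# to its local tensor (see the api.py hunk below).
dcp.save(state_dict, checkpoint_id="/tmp/checkpoint")
```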

See D54375878 for how it is used.

**LocalShardsWrapper supports the following torch ops:**
+ torch.ops._c10d_functional.all_gather_into_tensor.default
+ torch.ops._c10d_functional.wait_tensor.default
+ aten._to_copy.default
+ aten.view.default
+ aten.equal.default
+ aten.detach.default
+ aten.clone.default

More ops can be added as use cases require.
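
Dispatch goes through a small handler table in `__torch_dispatch__` (see `_shards_wrapper.py` below): each supported op maps to a static per-shard handler, and an unsupported op raises `NotImplementedError`. A rough demo of how the existing handlers are reached (illustrative shapes only):

```python
import torch

from torch.distributed._tensor._shards_wrapper import LocalShardsWrapper

lsw = LocalShardsWrapper([torch.randn(4, 8), torch.randn(4, 4)], [(0, 0), (0, 8)])

cloned = lsw.clone()             # aten.clone.default  -> handle_clone (per shard)
detached = lsw.detach()          # aten.detach.default -> handle_detach
print(torch.equal(lsw, cloned))  # aten.equal.default  -> handle_equal, prints True

# aten.view, aten._to_copy, and the functional all_gather are routed the same
# way; supporting a new op is one handler plus one entry in the dispatcher dict.
```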

See https://docs.google.com/document/d/16Ptl50mGFJW2cljdF2HQ6FwsiA0scwbAbjx_4dhabJw/edit?usp=drivesdk for more info regarding design and approach.

NOTE: This version of LocalShardsWrapper does not support empty shards; that support is added in the next diff, D57063512, which enables CW (column-wise) sharding.

Test Plan:
` buck test mode/opt -c python.package_style=inplace aiplatform/modelstore/client/tests_gpu:dist_checkpoint_save_load_with_stateful_tests -- --print-passing-details`

`buck2 test 'fbcode//mode/dev-nosan' fbcode//torchrec/distributed/tests:test_tensor_configs -- --print-passing-details`

Sandcastle

Reviewed By: XilunWu, wanchaol

Differential Revision: D58570479

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129150
Approved by: https://github.com/XilunWu
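
The diff below also adds a runtime-checkable `_Checkpointable` protocol (`torch/distributed/_checkpointable.py`) and switches the DCP helpers to duck-typed `hasattr` checks, so any tensor subclass can opt into DCP by defining the three dunders. A hedged sketch of such an opt-in (the subclass and its single-chunk layout are invented for illustration):

```python
from typing import Any, List

import torch
from torch.distributed.checkpoint.metadata import (
    ChunkStorageMetadata,
    MetadataIndex,
    TensorProperties,
)
from torch.distributed.checkpoint.planner import TensorWriteData, WriteItem, WriteItemType


class MyCheckpointableTensor(torch.Tensor):
    """Hypothetical tensor subclass exposing a single chunk at offset zero."""

    def __create_write_items__(self, fqn: str, object: Any) -> List[WriteItem]:
        chunk = ChunkStorageMetadata(
            offsets=torch.Size([0] * self.dim()), sizes=self.size()
        )
        return [
            WriteItem(
                index=MetadataIndex(fqn, chunk.offsets),
                type=WriteItemType.SHARD,
                tensor_data=TensorWriteData(
                    chunk=chunk,
                    properties=TensorProperties.create_from_tensor(self),
                    size=self.size(),
                ),
            )
        ]

    def __create_chunk_list__(self) -> List[ChunkStorageMetadata]:
        return [
            ChunkStorageMetadata(
                offsets=torch.Size([0] * self.dim()), sizes=self.size()
            )
        ]

    def __get_tensor_shard__(self, index: MetadataIndex) -> torch.Tensor:
        # Single chunk, so any index resolves to the full local tensor.
        return self
```
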
diff --git a/torch/distributed/_checkpointable.py b/torch/distributed/_checkpointable.py
new file mode 100644
index 0000000..409d1f8
--- /dev/null
+++ b/torch/distributed/_checkpointable.py
@@ -0,0 +1,38 @@
+# mypy: allow-untyped-defs
+# Copyright (c) Meta Platforms, Inc. and affiliates
+from typing import Any, Protocol, runtime_checkable
+
+import torch
+
+
+@runtime_checkable
+class _Checkpointable(Protocol):  # noqa: PYI046
+    """
+    Interface for checkpointable objects.
+    Implemented as a Protocol, so implicit subtyping is supported and subclasses do not need to inherit from it explicitly.
+    This is to allow arbitrary objects/tensor subclasses to hook into DCP seamlessly through implementing the interface.
+    """
+
+    def __create_write_items__(self, fqn: str, object: Any):
+        """
+        Return a list of WriteItems based on object's contents.
+        """
+        raise NotImplementedError(
+            "_Checkpointable._create_write_items is not implemented"
+        )
+
+    def __create_chunk_list__(self):
+        """
+        Return a list of `ChunkStorageMetadata` based on object's contents.
+        """
+        raise NotImplementedError(
+            "_Checkpointable._create_chunk_list is not implemented"
+        )
+
+    def __get_tensor_shard__(self, index) -> torch.Tensor:
+        """
+        Return a 'torch.Tensor' shard based on 'MetadataIndex'.
+        """
+        raise NotImplementedError(
+            "_Checkpointable._get_tensor_shard is not implemented"
+        )
diff --git a/torch/distributed/_tensor/_shards_wrapper.py b/torch/distributed/_tensor/_shards_wrapper.py
new file mode 100644
index 0000000..566fad9
--- /dev/null
+++ b/torch/distributed/_tensor/_shards_wrapper.py
@@ -0,0 +1,315 @@
+# mypy: allow-untyped-defs
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Any, List, Tuple
+
+import torch
+from torch.distributed.checkpoint.metadata import (
+    ChunkStorageMetadata,
+    MetadataIndex,
+    TensorProperties,
+    TensorStorageMetadata,
+)
+from torch.distributed.checkpoint.planner import (
+    TensorWriteData,
+    WriteItem,
+    WriteItemType,
+)
+
+aten = (
+    torch.ops.aten
+)  # pyre-ignore[5]: Globally accessible variable `aten` has no type specified.
+
+
+class LocalShardsWrapper(torch.Tensor):  # pyre-ignore[13]: pyre is bad at __new__
+    """
+    A wrapper class to hold local shards of a DTensor.
+    This class is used largely for checkpointing purposes and implicitly subtypes
+    the _Checkpointable protocol.
+    """
+
+    __slots__ = ["_local_shards", "_storage_meta"]
+    _local_shards: List[torch.Tensor]
+    _storage_meta: TensorStorageMetadata
+
+    @staticmethod
+    def __new__(
+        cls, local_shards: List[torch.Tensor], local_offsets: List[Tuple[int, ...]]
+    ) -> "LocalShardsWrapper":
+        assert len(local_shards) > 0
+        assert len(local_shards) == len(local_offsets)
+        assert all(
+            tensor.device == local_shards[0].device for tensor in local_shards[1:]
+        )
+
+        # we calculate the total tensor size by "concat" on second tensor dimension
+        cat_tensor_shape = list(local_shards[0].size())
+        if len(local_shards) > 1:  # column-wise sharding
+            for shard in local_shards[1:]:
+                cat_tensor_shape[1] += shard.size()[1]
+
+        wrapper_properties = TensorProperties.create_from_tensor(local_shards[0])
+        wrapper_shape = torch.Size(cat_tensor_shape)
+        chunks_meta = [
+            ChunkStorageMetadata(
+                offsets=torch.Size(offset),
+                sizes=shard.size(),
+            )
+            for shard, offset in zip(local_shards, local_offsets)
+        ]
+
+        r = torch.Tensor._make_wrapper_subclass(  # type: ignore[attr-defined]
+            cls,
+            torch.Size(cat_tensor_shape),
+        )
+        r._local_shards = local_shards
+        r._storage_meta = TensorStorageMetadata(
+            properties=wrapper_properties,
+            size=wrapper_shape,
+            chunks=chunks_meta,
+        )
+
+        return r
+
+    # necessary for ops dispatching from this subclass to its local shards
+    @classmethod
+    # pyre-fixme[3]: Return type must be annotated.
+    # pyre-fixme[2]: Parameter must be annotated.
+    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+        kwargs = kwargs or {}
+
+        dispatcher = {
+            torch.ops._c10d_functional.all_gather_into_tensor.default: cls.handle_all_gather_into_tensor,
+            torch.ops._c10d_functional.wait_tensor.default: cls.handle_wait_tensor,
+            aten._to_copy.default: cls.handle_to_copy,
+            aten.view.default: cls.handle_view,
+            aten.equal.default: cls.handle_equal,
+            aten.detach.default: cls.handle_detach,
+            aten.clone.default: cls.handle_clone,
+        }
+
+        if func in dispatcher:
+            return dispatcher[func](
+                args, kwargs
+            )  # pyre-ignore [29] - `Variable[_VT]` is not a function.
+        else:
+            raise NotImplementedError(
+                f"{func} is not supported for LocalShardsWrapper!"
+            )
+
+    @staticmethod
+    # pyre-fixme[3]: Return type must be annotated.
+    # pyre-fixme[2]: Parameter must be annotated.
+    def handle_all_gather_into_tensor(args, kwargs):
+        dim = args[0].local_sizes()[0][1]
+        cat_tensor = torch.cat(
+            [t.view(-1) for t in args[0].local_shards()], dim=0
+        ).view(-1, dim)
+        return torch.ops._c10d_functional.all_gather_into_tensor.default(
+            cat_tensor, *args[1:], **kwargs
+        )
+
+    @staticmethod
+    # pyre-fixme[3]: Return type must be annotated.
+    # pyre-fixme[2]: Parameter must be annotated.
+    def handle_wait_tensor(args, kwargs):
+        return torch.ops._c10d_functional.wait_tensor(args[0])
+
+    @staticmethod
+    # pyre-fixme[3]: Return type must be annotated.
+    # pyre-fixme[2]: Parameter must be annotated.
+    def handle_to_copy(args, kwargs):
+        res_shards_list = [
+            aten._to_copy.default(shard, *args[1:], **kwargs)
+            for shard in args[0].local_shards()
+        ]
+        return LocalShardsWrapper(res_shards_list, args[0].local_offsets())
+
+    @staticmethod
+    # pyre-fixme[3]: Return type must be annotated.
+    # pyre-fixme[2]: Parameter must be annotated.
+    def handle_view(args, kwargs):
+        # TODO, do we need to change the shape of associated offsets?
+        res_shards_list = [
+            aten.view.default(shard, args[1], **kwargs)
+            for shard in args[0].local_shards()
+        ]
+        return LocalShardsWrapper(res_shards_list, args[0].local_offsets())
+
+    @staticmethod
+    # pyre-fixme[3]: Return type must be annotated.
+    # pyre-fixme[2]: Parameter must be annotated.
+    def handle_equal(args, kwargs):
+        """
+        LocalShardsWrapper equal impl also checks for equality of storage metadata
+        and the order of shards
+        """
+        a, b = args[0], args[1]
+        if len(a.local_shards()) != len(b.local_shards()):
+            return False
+        if not all(
+            aten.equal.default(x, y) for x, y in zip(a.local_shards(), b.local_shards())
+        ):
+            return False
+        if not a.storage_metadata() == b.storage_metadata():
+            return False
+        return True
+
+    @staticmethod
+    # pyre-fixme[3]: Return type must be annotated.
+    # pyre-fixme[2]: Parameter must be annotated.
+    def handle_detach(args, kwargs):
+        self_ls = args[0]
+        detached_local_shards = [
+            aten.detach.default(shard) for shard in self_ls.local_shards()
+        ]
+        self_ls._local_shards = detached_local_shards
+        self_ls._storage_meta.properties.requires_grad = False
+        return self_ls
+
+    @staticmethod
+    # pyre-fixme[3]: Return type must be annotated.
+    # pyre-fixme[2]: Parameter must be annotated.
+    def handle_clone(args, kwargs):
+        self_ls = args[0]
+        desired_memory_format = kwargs.get("memory_format", None)
+        if desired_memory_format and desired_memory_format != torch.preserve_format:
+            raise NotImplementedError(
+                f"{desired_memory_format} is not supported for LocalShardsWrapper!"
+            )
+        cloned_local_shards = [
+            shard.clone(memory_format=desired_memory_format)
+            for shard in self_ls._local_shards
+        ]
+        return LocalShardsWrapper(cloned_local_shards, self_ls.local_offsets())
+
+    @property
+    def device(self) -> torch._C.device:  # type: ignore[override]
+        return self._local_shards[0].device
+
+    @property
+    def is_meta(self) -> bool:  # type: ignore[override]
+        return self._local_shards[0].is_meta
+
+    # pyre-ignore[14]
+    def is_pinned(self) -> bool:  # type: ignore[override]
+        return self._storage_meta.properties.pin_memory
+
+    # pyre-ignore[14]
+    def requires_grad_(self, requires_grad: bool = True) -> "LocalShardsWrapper":
+        self._storage_meta.properties.requires_grad = requires_grad
+        [shard.requires_grad_(requires_grad) for shard in self._local_shards]
+        return self
+
+    def local_shards(self) -> List[torch.Tensor]:
+        """
+        Returns a list of :class:`torch.Tensor` corresponding to the
+        local shards for this rank. Returns an empty list if the current rank
+        does not host any shards for this Tensor.
+        """
+        return self._local_shards
+
+    def local_sizes(self) -> List[torch.Size]:
+        """
+        Returns a list of :class:`torch.Size` corresponding to the
+        local sizes for the shards on this rank. Returns an empty list if the current rank
+        does not host any shards for this Tensor.
+        """
+        return [chunk.sizes for chunk in self._storage_meta.chunks]
+
+    def local_offsets(self) -> List[torch.Size]:
+        """
+        Returns a list of :class:`torch.Size` corresponding to the
+        local offsets for the shards on this rank. Returns an empty list if the current rank
+        does not host any shards for this Tensor.
+        """
+        return [chunk.offsets for chunk in self._storage_meta.chunks]
+
+    @property
+    def local_chunks(self) -> List[ChunkStorageMetadata]:
+        """
+        Returns a :class:`List[ChunkStorageMetadata]` object corresponding to the
+        metadata for each tensor shard
+        """
+        return self._storage_meta.chunks
+
+    def storage_metadata(self) -> TensorStorageMetadata:
+        """
+        Returns a :class:`TensorStorageMetadata` object corresponding to the
+        metadata for the local tensor on current rank
+        """
+        return self._storage_meta
+
+    def __create_write_items__(
+        self, fqn: str, object: Any
+    ) -> List[WriteItem]:  # pyre-ignore[2]
+        """
+        For compatibility with DCP, we support creation of WriteItems
+        such that they can be saved properly.
+        """
+        return [
+            WriteItem(
+                index=MetadataIndex(fqn, chunks.offsets),
+                type=WriteItemType.SHARD,
+                tensor_data=TensorWriteData(
+                    chunk=ChunkStorageMetadata(
+                        offsets=chunks.offsets,
+                        sizes=chunks.sizes,
+                    ),
+                    properties=self._storage_meta.properties,
+                    size=object.size(),
+                ),
+            )
+            for tensor, chunks in zip(self.local_shards(), self.local_chunks)
+        ]
+
+    def __create_chunk_list__(self) -> List[ChunkStorageMetadata]:
+        """
+        For compatibility with DCP, we support creation of chunk lists
+        such that they can be saved properly.
+        """
+        return self._storage_meta.chunks
+
+    def __get_tensor_shard__(self, index: MetadataIndex) -> torch.Tensor:
+        """
+        For compatibility with DCP, we support finding a shard based on its index.
+        Return a 'torch.Tensor' shard based on 'MetadataIndex'.
+        """
+        # Fast lookup path
+        if index.index is not None:
+            if (
+                len(self._local_shards) > index.index
+                and self._storage_meta.chunks[index.index].offsets == index.offset
+            ):
+                return self._local_shards[index.index]
+
+        if index.offset is not None:
+            for shard, chunk in zip(self._local_shards, self._storage_meta.chunks):
+                if chunk.offsets == index.offset:
+                    return shard
+
+        raise ValueError(
+            f"Could not find shard at '{index.offset}' for FQN: '{index.fqn}'"
+        )
+
+    def _get_tensor_size_bytes(self) -> int:
+        object_size = 0
+        for shard in self.local_shards():
+            object_size += shard.nelement() * shard.element_size()
+        return object_size
+
+    # pyre-fixme[3]: Return type must be annotated.
+    def __hash__(self):
+        return id(self)
+
+    # pyre-fixme[14]: `__repr__` overrides method defined in `torch._tensor.Tensor` inconsistently.
+    # pyre-fixme[3]: Return type must be annotated.
+    def __repr__(self):
+        return f"LocalShardsWrapper:{self._local_shards} {self._storage_meta}"
+
+    def __str__(self) -> str:
+        return f"LocalShardsWrapper:{self._local_shards} {self._storage_meta}"
diff --git a/torch/distributed/_tensor/api.py b/torch/distributed/_tensor/api.py
index e1c01040..187b537 100644
--- a/torch/distributed/_tensor/api.py
+++ b/torch/distributed/_tensor/api.py
@@ -532,6 +532,38 @@
         """
         return self._spec.placements
 
+    def __create_write_items__(self, fqn: str, object: Any):
+        from torch.distributed.checkpoint.planner_helpers import (
+            _create_write_items_for_dtensor,
+        )
+
+        if hasattr(self._local_tensor, "__create_write_items__"):
+            return self._local_tensor.__create_write_items__(fqn, object)  # type: ignore[attr-defined]
+        elif isinstance(self._local_tensor, torch.Tensor):
+            return [_create_write_items_for_dtensor(fqn, object)]
+        else:
+            raise RuntimeError("Unsupported tensor type!")
+
+    def __create_chunk_list__(self):
+        from torch.distributed.checkpoint.planner_helpers import (
+            _create_chunk_from_dtensor,
+        )
+
+        if hasattr(self._local_tensor, "__create_chunk_list__"):
+            return self._local_tensor.__create_chunk_list__()  # type: ignore[attr-defined]
+        elif isinstance(self._local_tensor, torch.Tensor):
+            return [_create_chunk_from_dtensor(self)]
+        else:
+            raise RuntimeError("Unsupported tensor type!")
+
+    def __get_tensor_shard__(self, index):
+        if hasattr(self._local_tensor, "__get_tensor_shard__"):
+            return self._local_tensor.__get_tensor_shard__(index)  # type: ignore[attr-defined]
+        elif isinstance(self._local_tensor, torch.Tensor):
+            return self.to_local()
+        else:
+            raise RuntimeError("Unsupported tensor type!")
+
 
 def distribute_tensor(
     tensor: torch.Tensor,
diff --git a/torch/distributed/checkpoint/planner.py b/torch/distributed/checkpoint/planner.py
index d3e7995..c8a0cb2 100644
--- a/torch/distributed/checkpoint/planner.py
+++ b/torch/distributed/checkpoint/planner.py
@@ -425,39 +425,3 @@
         The contents of tensor will follow its device synchronization model.
         """
         pass
-
-
-class _Checkpointable:
-    """
-    Interface for checkpointable objects.
-    This is to allow arbitrary objects/tensor subclasses to hook into DCP seamlessly through implementing the interface.
-    """
-
-    @abc.abstractmethod
-    def _create_write_items(self, fqn: str, object: Any) -> List[WriteItem]:
-        """
-        Return a list of WriteItems based on object's contents.
-        """
-        raise NotImplementedError(
-            "_Checkpointable._create_write_items is not implemented"
-        )
-
-    @abc.abstractmethod
-    def _create_chunk_list(self, tensor: torch.Tensor) -> List[ChunkStorageMetadata]:
-        """
-        Return a list of `ChunkStorageMetadata` based on object's contents.
-        """
-        raise NotImplementedError(
-            "_Checkpointable._create_chunk_list is not implemented"
-        )
-
-    @abc.abstractmethod
-    def _get_tensor_shard(
-        self, tensor: torch.Tensor, index: MetadataIndex
-    ) -> torch.Tensor:
-        """
-        Return a 'torch.Tensor' shard based on 'MetadataIndex'.
-        """
-        raise NotImplementedError(
-            "_Checkpointable._get_tensor_shard is not implemented"
-        )
diff --git a/torch/distributed/checkpoint/planner_helpers.py b/torch/distributed/checkpoint/planner_helpers.py
index 56e1728..1225271 100644
--- a/torch/distributed/checkpoint/planner_helpers.py
+++ b/torch/distributed/checkpoint/planner_helpers.py
@@ -8,7 +8,6 @@
 from torch.distributed._shard.sharded_tensor import ShardedTensor
 from torch.distributed._tensor import DTensor
 from torch.distributed._tensor._utils import compute_local_shape_and_global_offset
-from torch.distributed.checkpoint.planner import _Checkpointable
 from torch.utils._pytree import tree_map_only
 
 from .metadata import (
@@ -218,13 +217,9 @@
 
 
 def _create_write_items(fqn: str, object: Any) -> List[WriteItem]:
-    if isinstance(object, _Checkpointable):
-        return object._create_write_items(fqn, object)
-    elif isinstance(object, DTensor):
-        # DTensor can contain a local tensor that is a tensor subclass
-        if isinstance(object.to_local(), _Checkpointable):
-            return object.to_local()._create_write_items(fqn, object)  # type: ignore[arg-type]
-        return [_create_write_items_for_dtensor(fqn, object)]
+    if hasattr(object, "__create_write_items__"):
+        # DTensor implements _Checkpointable
+        return object.__create_write_items__(fqn, object)
     elif isinstance(object, ShardedTensor):
         return [
             _create_write_item_for_shard(fqn, object, shard.metadata)
@@ -248,13 +243,9 @@
 
 
 def _create_chunk_list(tensor: torch.Tensor) -> List[ChunkStorageMetadata]:
-    if isinstance(tensor, _Checkpointable):
-        local_chunks = tensor._create_chunk_list(tensor)
-    elif isinstance(tensor, DTensor):
-        # DTensor can contain a local tensor that is a tensor subclass
-        if isinstance(tensor.to_local(), _Checkpointable):
-            return tensor.to_local()._create_chunk_list(tensor)  # type: ignore[arg-type]
-        local_chunks = [_create_chunk_from_dtensor(tensor)]
+    if hasattr(tensor, "__create_chunk_list__"):
+        # DTensor implements _Checkpointable
+        local_chunks = tensor.__create_chunk_list__()  # type: ignore[attr-defined]
     elif isinstance(tensor, ShardedTensor):
         local_chunks = [
             _chunk_for_shard(shard.metadata) for shard in tensor.local_shards()
diff --git a/torch/distributed/checkpoint/utils.py b/torch/distributed/checkpoint/utils.py
index 3264945..5bc171a 100644
--- a/torch/distributed/checkpoint/utils.py
+++ b/torch/distributed/checkpoint/utils.py
@@ -14,8 +14,6 @@
 import torch.distributed as dist
 from torch.distributed._shard.sharded_tensor import ShardedTensor
 from torch.distributed._shard.sharded_tensor.shard import Shard
-from torch.distributed._tensor import DTensor
-from torch.distributed.checkpoint.planner import _Checkpointable
 
 from .api import (
     _is_wrapped_exception,
@@ -304,13 +302,9 @@
 
 
 def find_tensor_shard(tensor: torch.Tensor, index: MetadataIndex) -> torch.Tensor:
-    if isinstance(tensor, _Checkpointable):
-        return tensor._get_tensor_shard(tensor, index)
-    elif isinstance(tensor, DTensor):
-        # DTensor can contain a local tensor that is a tensor subclass
-        if isinstance(tensor.to_local(), _Checkpointable):
-            return tensor.to_local()._get_tensor_shard(tensor, index)  # type: ignore[arg-type]
-        return tensor.to_local()
+    if hasattr(tensor, "__get_tensor_shard__"):
+        # DTensor implements _Checkpointable
+        return tensor.__get_tensor_shard__(index)  # type: ignore[attr-defined]
     if isinstance(tensor, ShardedTensor):
         return _find_shard(tensor, index).tensor
     if index.offset is not None: