Use dataclasses to simplify ShardingSpec (#58893)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58893

Leverage Python dataclasses to simplify several of the ShardingSpec classes: the hand-written __init__ methods, __repr__, and read-only properties are replaced by dataclass fields, and input validation moves into __post_init__.
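
For reviewers, the overall pattern is roughly the sketch below (SpecBefore/SpecAfter are illustrative names only, not part of this diff): a class with an explicit __init__ and read-only properties becomes a dataclass whose fields replace the properties and whose validation runs in __post_init__.

    from dataclasses import dataclass

    # Before: explicit __init__, private attribute, read-only property.
    class SpecBefore(object):
        def __init__(self, dim: int):
            if dim < 0:
                raise ValueError(f'{dim} is not a valid dim')
            self._dim = dim

        @property
        def dim(self) -> int:
            return self._dim

    # After: the decorator generates __init__/__repr__/__eq__, and the
    # validation that used to live in __init__ moves to __post_init__.
    @dataclass
    class SpecAfter(object):
        dim: int

        def __post_init__(self):
            if self.dim < 0:
                raise ValueError(f'{self.dim} is not a valid dim')
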
ghstack-source-id: 130041687

Test Plan: waitforbuildbot

Reviewed By: SciPioneer

Differential Revision: D28665137

fbshipit-source-id: da37517cf2bd8c65d4a5b7cae171fa460e6b0946
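
Note on the classes that keep __slots__ (Shard, ShardMetadata): combining __slots__ with dataclass fields works only because none of the fields have default values; a default would be stored as a class attribute and clash with the slot descriptor of the same name at class creation time. A minimal sketch of that constraint (assumed rationale, not stated in the diff):

    from dataclasses import dataclass

    @dataclass
    class Ok(object):
        # A field without a default coexists with __slots__: the bare
        # annotation does not create a class attribute named 'tensor'.
        __slots__ = ['tensor']
        tensor: object

    # By contrast, adding a default would fail before @dataclass even runs:
    #   ValueError: 'tensor' in __slots__ conflicts with class variable
    # @dataclass
    # class Broken(object):
    #     __slots__ = ['tensor']
    #     tensor: object = None
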
diff --git a/torch/distributed/_sharded_tensor/api.py b/torch/distributed/_sharded_tensor/api.py
index 89125cd..38bea61 100644
--- a/torch/distributed/_sharded_tensor/api.py
+++ b/torch/distributed/_sharded_tensor/api.py
@@ -1,3 +1,4 @@
+from dataclasses import dataclass
 from typing import List
 
 import torch
@@ -11,24 +12,16 @@
 from torch.distributed.utils import _parse_remote_device
 
 
+@dataclass
 class Shard(object):
     """
     Container which holds the data for a shard as a Tensor and also
     the associated metadata for that shard.
     """
-    __slots__ = ['_tensor', '_metadata']
+    __slots__ = ['tensor', 'metadata']
 
-    def __init__(self, tensor: torch.Tensor, metadata: ShardMetadata):
-        self._tensor = tensor
-        self._metadata = metadata
-
-    @property
-    def tensor(self) -> torch.Tensor:
-        return self._tensor
-
-    @property
-    def metadata(self) -> ShardMetadata:
-        return self._metadata
+    tensor: torch.Tensor
+    metadata: ShardMetadata
 
 
 class ShardedTensor(object):
diff --git a/torch/distributed/_sharding_spec/api.py b/torch/distributed/_sharding_spec/api.py
index 1bc1219..451cb23 100644
--- a/torch/distributed/_sharding_spec/api.py
+++ b/torch/distributed/_sharding_spec/api.py
@@ -1,4 +1,5 @@
 from abc import ABC
+from dataclasses import dataclass
 import torch
 from typing import List, Union
 
@@ -14,6 +15,8 @@
     """
     pass
 
+
+@dataclass
 class DevicePlacementSpec(PlacementSpec):
     """
     Associates placement of an entity with a single device. The device can be a
@@ -26,18 +29,12 @@
     Args:
         device(str, :class:`torch.device`): The device to place the entity on.
     """
-    def __init__(self, device: Device):
-        super(DevicePlacementSpec, self).__init__()
-        if not is_valid_device(device):
-            raise ValueError(f'{device} is not a valid device')
-        self._device = device
 
-    @property
-    def device(self) -> Device:
-        """
-        Retrieves the device for placement.
-        """
-        return self._device
+    device: Device
+
+    def __post_init__(self):
+        if not is_valid_device(self.device):
+            raise ValueError(f'{self.device} is not a valid device')
 
 
 class ShardingSpec(PlacementSpec):
@@ -48,6 +45,7 @@
     pass
 
 
+@dataclass
 class ChunkShardingSpec(ShardingSpec):
     """
     This is a type of PlacementSpec that defines the placement as being sharded
@@ -81,12 +79,12 @@
     ShardingDim = Union[int, str]
     ShardPlacements = List[Union[Device, PlacementSpec]]
 
-    def __init__(self, dim: ShardingDim, placements: ShardPlacements):
-        super(ChunkShardingSpec, self).__init__()
-        self._verify_dim(dim)
-        self._verify_devices(placements)
-        self._dim = dim
-        self._placements = placements
+    dim: ShardingDim
+    placements: ShardPlacements
+
+    def __post_init__(self):
+        self._verify_dim(self.dim)
+        self._verify_devices(self.placements)
 
     @staticmethod
     def _verify_devices(placements):
@@ -101,20 +99,8 @@
         if not (isinstance(dim, int) or isinstance(dim, str)):
             raise ValueError(f'{dim} needs to either be an int or str')
 
-    @property
-    def dim(self) -> ShardingDim:
-        """
-        Retrieves the dimension to shard on.
-        """
-        return self._dim
 
-    @property
-    def placements(self) -> ShardPlacements:
-        """
-        Retrieves the shard placements.
-        """
-        return self._placements
-
+@dataclass
 class ShardMetadata(object):
     """
     Represents a shard of the overall Tensor including its
@@ -143,77 +129,54 @@
 
     ShardPlacement = Union[Device, PlacementSpec]
 
-    __slots__ = ['_shard_offsets', '_shard_lengths', '_placement']
+    __slots__ = ['shard_offsets', 'shard_lengths', 'placement']
 
-    def __init__(
-            self,
-            shard_offsets: List[int],
-            shard_lengths: List[int],
-            placement: ShardPlacement):
+    shard_offsets: List[int]
+    shard_lengths: List[int]
+    placement: ShardPlacement
 
-        if not isinstance(placement, PlacementSpec) and not is_valid_device(placement):
-            raise ValueError(f'{placement} is not a valid device')
+    def __post_init__(self):
+        if not isinstance(self.placement, PlacementSpec) and not is_valid_device(self.placement):
+            raise ValueError(f'{self.placement} is not a valid device')
 
-        if len(shard_offsets) != len(shard_lengths):
+        if len(self.shard_offsets) != len(self.shard_lengths):
             raise ValueError(
                 f'shard_offsets and shard_lengths should have '
-                f'the same number of elements, found {len(shard_offsets)} '
-                f'and {shard_lengths} respectively')
+                f'the same number of elements, found {len(self.shard_offsets)} '
+                f'and {len(self.shard_lengths)} respectively')
 
-        for i in range(len(shard_offsets)):
-            if shard_offsets[i] < 0:
+        for i in range(len(self.shard_offsets)):
+            if self.shard_offsets[i] < 0:
                 raise ValueError('shard_offsets should be >=0')
-            if shard_lengths[i] <= 0:
+            if self.shard_lengths[i] <= 0:
                 raise ValueError('shard_lengths should be > 0')
 
-        self._shard_offsets = shard_offsets
-        self._shard_lengths = shard_lengths
-        self._placement = placement
 
-    def __repr__(self):
-        return (
-            f'ShardMetadata(shard_offsets: {self._shard_offsets}, '
-            f'shard_lengths: {self._shard_lengths}, placement: {self._placement})'
-        )
-
-    @property
-    def shard_offsets(self):
-        return self._shard_offsets
-
-    @property
-    def shard_lengths(self):
-        return self._shard_lengths
-
-    @property
-    def placement(self):
-        return self._placement
-
-
+@dataclass
 class EnumerableShardingSpec(ShardingSpec):
+    """
+    This is a type of PlacementSpec that allows users to specify a generic
+    sharding scheme by enumerating exactly how each shard is laid out.
 
-    def __init__(self, shards: List[ShardMetadata]):
-        """
-        This is a type of PlacementSpec that allows users to specify a generic
-        sharding scheme by enumerating exactly how each shard is laid out.
+    Args:
+        shards(List[ShardMetadata]): List of :class:`ShardMetadata` objects representing
+            each shard.
+    """
 
-        Args:
-            shards(List[ShardMetadata]): List of :class:`ShardMetadata` objects representing
-                each shard.
-        """
-        super(EnumerableShardingSpec, self).__init__()
-        if len(shards) == 0:
-            raise ValueError(f'Empty shard list provided: {shards}')
+    shards: List[ShardMetadata]
+
+    def __post_init__(self):
+        if len(self.shards) == 0:
+            raise ValueError(f'Empty shard list provided: {self.shards}')
 
         # Validate each shard has same rank.
         rank = -1
-        for shard in shards:
+        for shard in self.shards:
             if rank != -1 and rank != len(shard.shard_offsets):
                 raise ValueError(f'Found inconsistent ranks for shards: {rank} and {len(shard.shard_offsets)}')
             rank = len(shard.shard_offsets)
 
-        self._validate_non_overlapping(shards)
-
-        self._shards = shards
+        self._validate_non_overlapping(self.shards)
 
     @staticmethod
     def _validate_non_overlapping(shards: List[ShardMetadata]):
@@ -245,10 +208,6 @@
 
         return True
 
-    @property
-    def shards(self):
-        return self._shards
-
     def check_tensor(self, tensor: torch.Tensor) -> None:
         """
         Checks if the sharding spec is compatible with the provided tensor.
@@ -264,13 +223,13 @@
         # sharding spec for this tensor. Note that we have already verified
         # we don't have overlapping shards.
         tensor_rank = len(tensor.size())
-        shards_rank = len(self._shards[0].shard_offsets)
+        shards_rank = len(self.shards[0].shard_offsets)
         if tensor_rank != shards_rank:
             raise ValueError(f'Rank of tensor is {tensor_rank}, but shards rank is {shards_rank}')
 
         total_shard_volume = 0
         tensor_dims = tensor.size()
-        for shard in self._shards:
+        for shard in self.shards:
             shard_volume = 1
             for i, shard_length in enumerate(shard.shard_lengths):
                 shard_volume *= shard_length