Use dataclasses to simplify ShardingSpec (#58893)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58893
Leverage dataclasses to simplify the ShardingSpec classes: replace the hand-written __init__ methods and read-only properties with dataclass fields, and move argument validation into __post_init__.
ghstack-source-id: 130041687
Test Plan: waitforbuildbot
Reviewed By: SciPioneer
Differential Revision: D28665137
fbshipit-source-id: da37517cf2bd8c65d4a5b7cae171fa460e6b0946
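
For context, a minimal sketch of the pattern applied throughout this diff; the Placement* and _is_valid names below are illustrative stand-ins, not the actual classes touched here. Hand-written __init__/property pairs become dataclass fields, and validation moves into __post_init__:

    # Minimal sketch of the refactoring pattern; names are illustrative
    # stand-ins, not the classes changed in this diff.
    from dataclasses import dataclass

    def _is_valid(device: str) -> bool:
        # Stand-in for the codebase's is_valid_device() check.
        return bool(device)

    # Before: explicit __init__ plus a read-only property per field.
    class PlacementBefore:
        def __init__(self, device: str):
            if not _is_valid(device):
                raise ValueError(f'{device} is not a valid device')
            self._device = device

        @property
        def device(self) -> str:
            return self._device

    # After: @dataclass generates __init__, __repr__ and __eq__;
    # the validation runs in __post_init__.
    @dataclass
    class PlacementAfter:
        device: str

        def __post_init__(self):
            if not _is_valid(self.device):
                raise ValueError(f'{self.device} is not a valid device')
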
diff --git a/torch/distributed/_sharded_tensor/api.py b/torch/distributed/_sharded_tensor/api.py
index 89125cd..38bea61 100644
--- a/torch/distributed/_sharded_tensor/api.py
+++ b/torch/distributed/_sharded_tensor/api.py
@@ -1,3 +1,4 @@
+from dataclasses import dataclass
from typing import List
import torch
@@ -11,24 +12,16 @@
from torch.distributed.utils import _parse_remote_device
+@dataclass
class Shard(object):
"""
Container which holds the data for a shard as a Tensor and also
the associated metadata for that shard.
"""
- __slots__ = ['_tensor', '_metadata']
+ __slots__ = ['tensor', 'metadata']
- def __init__(self, tensor: torch.Tensor, metadata: ShardMetadata):
- self._tensor = tensor
- self._metadata = metadata
-
- @property
- def tensor(self) -> torch.Tensor:
- return self._tensor
-
- @property
- def metadata(self) -> ShardMetadata:
- return self._metadata
+ tensor: torch.Tensor
+ metadata: ShardMetadata
class ShardedTensor(object):
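
A quick, hypothetical usage sketch for the converted Shard class (the tensor shape, offsets, lengths and placement string below are made-up values, and the private import paths simply mirror the files touched in this diff): attribute access is unchanged, while the dataclass-generated __repr__ and __eq__ come for free.

    # Hypothetical usage sketch; shapes, offsets and the placement string
    # are illustrative values, not taken from this PR.
    import torch
    from torch.distributed._sharded_tensor.api import Shard
    from torch.distributed._sharding_spec.api import ShardMetadata

    meta = ShardMetadata(shard_offsets=[0, 0], shard_lengths=[4, 4], placement="rank:0/cpu")
    shard = Shard(tensor=torch.zeros(4, 4), metadata=meta)
    shard.tensor    # same attribute access as the old property-based API
    print(shard)    # dataclass-generated __repr__ now covers both fields
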
diff --git a/torch/distributed/_sharding_spec/api.py b/torch/distributed/_sharding_spec/api.py
index 1bc1219..451cb23 100644
--- a/torch/distributed/_sharding_spec/api.py
+++ b/torch/distributed/_sharding_spec/api.py
@@ -1,4 +1,5 @@
from abc import ABC
+from dataclasses import dataclass
import torch
from typing import List, Union
@@ -14,6 +15,8 @@
"""
pass
+
+@dataclass
class DevicePlacementSpec(PlacementSpec):
"""
Associates placement of an entity with a single device. The device can be a
@@ -26,18 +29,12 @@
Args:
device(str, :class:`torch.device`): The device to place the entity on.
"""
- def __init__(self, device: Device):
- super(DevicePlacementSpec, self).__init__()
- if not is_valid_device(device):
- raise ValueError(f'{device} is not a valid device')
- self._device = device
- @property
- def device(self) -> Device:
- """
- Retrieves the device for placement.
- """
- return self._device
+ device: Device
+
+ def __post_init__(self):
+ if not is_valid_device(self.device):
+ raise ValueError(f'{self.device} is not a valid device')
class ShardingSpec(PlacementSpec):
@@ -48,6 +45,7 @@
pass
+@dataclass
class ChunkShardingSpec(ShardingSpec):
"""
This is a type of PlacementSpec that defines the placement as being sharded
@@ -81,12 +79,12 @@
ShardingDim = Union[int, str]
ShardPlacements = List[Union[Device, PlacementSpec]]
- def __init__(self, dim: ShardingDim, placements: ShardPlacements):
- super(ChunkShardingSpec, self).__init__()
- self._verify_dim(dim)
- self._verify_devices(placements)
- self._dim = dim
- self._placements = placements
+ dim: ShardingDim
+ placements: ShardPlacements
+
+ def __post_init__(self):
+ self._verify_dim(self.dim)
+ self._verify_devices(self.placements)
@staticmethod
def _verify_devices(placements):
@@ -101,20 +99,8 @@
if not (isinstance(dim, int) or isinstance(dim, str)):
raise ValueError(f'{dim} needs to either be an int or str')
- @property
- def dim(self) -> ShardingDim:
- """
- Retrieves the dimension to shard on.
- """
- return self._dim
- @property
- def placements(self) -> ShardPlacements:
- """
- Retrieves the shard placements.
- """
- return self._placements
-
+@dataclass
class ShardMetadata(object):
"""
Represents a shard of the overall Tensor including its
@@ -143,77 +129,54 @@
ShardPlacement = Union[Device, PlacementSpec]
- __slots__ = ['_shard_offsets', '_shard_lengths', '_placement']
+ __slots__ = ['shard_offsets', 'shard_lengths', 'placement']
- def __init__(
- self,
- shard_offsets: List[int],
- shard_lengths: List[int],
- placement: ShardPlacement):
+ shard_offsets: List[int]
+ shard_lengths: List[int]
+ placement: ShardPlacement
- if not isinstance(placement, PlacementSpec) and not is_valid_device(placement):
- raise ValueError(f'{placement} is not a valid device')
+ def __post_init__(self):
+ if not isinstance(self.placement, PlacementSpec) and not is_valid_device(self.placement):
+ raise ValueError(f'{self.placement} is not a valid device')
- if len(shard_offsets) != len(shard_lengths):
+ if len(self.shard_offsets) != len(self.shard_lengths):
raise ValueError(
f'shard_offsets and shard_lengths should have '
- f'the same number of elements, found {len(shard_offsets)} '
- f'and {shard_lengths} respectively')
+ f'the same number of elements, found {len(self.shard_offsets)} '
+ f'and {self.shard_lengths} respectively')
- for i in range(len(shard_offsets)):
- if shard_offsets[i] < 0:
+ for i in range(len(self.shard_offsets)):
+ if self.shard_offsets[i] < 0:
raise ValueError('shard_offsets should be >=0')
- if shard_lengths[i] <= 0:
+ if self.shard_lengths[i] <= 0:
raise ValueError('shard_lengths should be > 0')
- self._shard_offsets = shard_offsets
- self._shard_lengths = shard_lengths
- self._placement = placement
- def __repr__(self):
- return (
- f'ShardMetadata(shard_offsets: {self._shard_offsets}, '
- f'shard_lengths: {self._shard_lengths}, placement: {self._placement})'
- )
-
- @property
- def shard_offsets(self):
- return self._shard_offsets
-
- @property
- def shard_lengths(self):
- return self._shard_lengths
-
- @property
- def placement(self):
- return self._placement
-
-
+@dataclass
class EnumerableShardingSpec(ShardingSpec):
+ """
+ This is a type of PlacementSpec that allows users to specify a generic
+ sharding scheme by enumerating exactly how each shard is laid out.
- def __init__(self, shards: List[ShardMetadata]):
- """
- This is a type of PlacementSpec that allows users to specify a generic
- sharding scheme by enumerating exactly how each shard is laid out.
+ Args:
+ shards(List[ShardMetadata]): List of :class:`ShardMetadata` objects representing
+ each shard.
+ """
- Args:
- shards(List[ShardMetadata]): List of :class:`ShardMetadata` objects representing
- each shard.
- """
- super(EnumerableShardingSpec, self).__init__()
- if len(shards) == 0:
- raise ValueError(f'Empty shard list provided: {shards}')
+ shards: List[ShardMetadata]
+
+ def __post_init__(self):
+ if len(self.shards) == 0:
+ raise ValueError(f'Empty shard list provided: {self.shards}')
# Validate each shard has same rank.
rank = -1
- for shard in shards:
+ for shard in self.shards:
if rank != -1 and rank != len(shard.shard_offsets):
raise ValueError(f'Found inconsistent ranks for shards: {rank} and {len(shard.shard_offsets)}')
rank = len(shard.shard_offsets)
- self._validate_non_overlapping(shards)
-
- self._shards = shards
+ self._validate_non_overlapping(self.shards)
@staticmethod
def _validate_non_overlapping(shards: List[ShardMetadata]):
@@ -245,10 +208,6 @@
return True
- @property
- def shards(self):
- return self._shards
-
def check_tensor(self, tensor: torch.Tensor) -> None:
"""
Checks if the sharding spec is compatible with the provided tensor.
@@ -264,13 +223,13 @@
# sharding spec for this tensor. Note that we have already verified
# we don't have overlapping shards.
tensor_rank = len(tensor.size())
- shards_rank = len(self._shards[0].shard_offsets)
+ shards_rank = len(self.shards[0].shard_offsets)
if tensor_rank != shards_rank:
raise ValueError(f'Rank of tensor is {tensor_rank}, but shards rank is {shards_rank}')
total_shard_volume = 0
tensor_dims = tensor.size()
- for shard in self._shards:
+ for shard in self.shards:
shard_volume = 1
for i, shard_length in enumerate(shard.shard_lengths):
shard_volume *= shard_length
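
As a rough smoke test of the converted sharding specs (the devices, shapes and shard layout below are illustrative, and the private import paths mirror the files in this diff): the fields are plain dataclass attributes rather than properties, and the __post_init__ validation still rejects bad input from the generated __init__.

    # Hypothetical smoke test; placements, shapes and the shard layout are
    # illustrative values only.
    import torch
    from torch.distributed._sharding_spec.api import (
        ChunkShardingSpec, EnumerableShardingSpec, ShardMetadata)

    chunk_spec = ChunkShardingSpec(dim=0, placements=["rank:0/cpu", "rank:1/cpu"])
    print(chunk_spec.dim, chunk_spec.placements)   # plain fields, no properties

    shards = [
        ShardMetadata(shard_offsets=[0, 0], shard_lengths=[2, 4], placement="rank:0/cpu"),
        ShardMetadata(shard_offsets=[2, 0], shard_lengths=[2, 4], placement="rank:1/cpu"),
    ]
    enum_spec = EnumerableShardingSpec(shards)
    enum_spec.check_tensor(torch.zeros(4, 4))      # shards tile the 4x4 tensor exactly

    try:
        ShardMetadata(shard_offsets=[-1], shard_lengths=[1], placement="rank:0/cpu")
    except ValueError as e:
        print(e)   # validation still runs via the generated __init__ -> __post_init__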