from typing import Any, List, Sequence

import torch

from torch.distributed._shard.metadata import ShardMetadata
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._shard.sharded_tensor.metadata import TensorProperties
from torch.distributed._shard.sharded_tensor.shard import Shard

from torch.distributed._shard.sharding_spec._internals import (
    _check_shard_metadata_pair_overlap,
)

from .planner import (
    LoadItemType,
    ReadItem,
    SavePlan,
    TensorWriteData,
    WriteItem,
    WriteItemType,
)

from .metadata import (
    BytesStorageMetadata,
    ChunkStorageMetadata,
    MetadataIndex,
    STATE_DICT_TYPE,
    STORAGE_TYPES,
    TensorStorageMetadata,
)

from .resharding import _shards_get_overlap_region_wrt_saved_tensor

# All helpers in this module are private.
__all__: List[str] = []


def _create_shard_metadata(size: torch.Size) -> ShardMetadata:
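    """Build ShardMetadata for an unsharded tensor: one shard at offset zero."""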
    return ShardMetadata(
        shard_offsets=[0] * len(size),
        shard_sizes=list(size),
    )


def _create_shard_from_tensor(tensor: torch.Tensor) -> Shard:
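    """Wrap a regular tensor in a single Shard covering the whole tensor."""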
    return Shard(tensor=tensor, metadata=_create_shard_metadata(tensor.size()))


def _chunk_for_shard(shard_md: ShardMetadata) -> ChunkStorageMetadata:
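    """Convert a ShardMetadata into the equivalent ChunkStorageMetadata."""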
    return ChunkStorageMetadata(
        offsets=torch.Size(shard_md.shard_offsets),
        sizes=torch.Size(shard_md.shard_sizes),
    )


def _sharded_tensor_metadata(
    sharded_tensor: ShardedTensor, shard_md: ShardMetadata
) -> TensorWriteData:
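    """Build TensorWriteData for one shard of a ShardedTensor.

    The chunk describes the shard itself; ``properties`` and ``size`` refer
    to the full (global) tensor.
    """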
    return TensorWriteData(
        chunk=_chunk_for_shard(shard_md),
        properties=sharded_tensor.metadata().tensor_properties,
        size=sharded_tensor.metadata().size,
    )


def _create_write_item_for_shard(
    fqn: str, sharded_tensor: ShardedTensor, shard_md: ShardMetadata
) -> WriteItem:
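    """Create a WriteItem for a single shard of a ShardedTensor."""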
    offsets = torch.Size(shard_md.shard_offsets)
    return WriteItem(
        index=MetadataIndex(fqn, offsets),
        type=WriteItemType.SHARD,
        tensor_data=_sharded_tensor_metadata(sharded_tensor, shard_md),
    )


def _create_write_item_for_tensor(fqn: str, tensor: torch.Tensor) -> WriteItem:
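    """Create a WriteItem for a regular tensor, treated as one chunk at offset zero."""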
    offsets = torch.Size([0] * len(tensor.size()))
    return WriteItem(
        index=MetadataIndex(fqn, offsets),
        type=WriteItemType.TENSOR,
        tensor_data=TensorWriteData(
            chunk=ChunkStorageMetadata(offsets=offsets, sizes=tensor.size()),
            properties=TensorProperties.create_from_tensor(tensor),
            size=tensor.size(),
        ),
    )


def _create_write_item_for_bytesio(fqn: str, bytes_obj: Any) -> WriteItem:
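    """Create a WriteItem for an opaque (non-tensor) object saved as bytes.

    ``bytes_obj`` is intentionally unused: the object is serialized later,
    when the storage layer resolves this item's MetadataIndex against the
    state_dict. The parameter exists so all ``_create_write_item_*`` helpers
    share a uniform call shape.
    """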
    return WriteItem(
        index=MetadataIndex(fqn),
        type=WriteItemType.BYTE_IO,
    )


def _create_read_item_for_byteio(
    dest_index: MetadataIndex,
    dest_offset: int,
    storage_index: MetadataIndex,
    storage_offset: int,
    length: int,
) -> ReadItem:
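    """Create a ReadItem for a byte stream; offsets and length are 1-D."""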
    return ReadItem(
        type=LoadItemType.BYTE_IO,
        dest_index=dest_index,
        dest_offsets=torch.Size((dest_offset,)),
        storage_index=storage_index,
        storage_offsets=torch.Size((storage_offset,)),
        lengths=torch.Size((length,)),
    )


def _create_read_item_for_tensor(
    dest_index: MetadataIndex,
    dest_offsets: Sequence[int],
    storage_index: MetadataIndex,
    storage_offsets: Sequence[int],
    lengths: Sequence[int],
) -> ReadItem:
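    """Create a ReadItem describing one tensor copy.

    ``lengths`` elements are read from the stored chunk starting at
    ``storage_offsets`` and written into the destination tensor starting at
    ``dest_offsets``.
    """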
    return ReadItem(
        type=LoadItemType.TENSOR,
        dest_index=dest_index,
        dest_offsets=torch.Size(dest_offsets),
        storage_index=storage_index,
        storage_offsets=torch.Size(storage_offsets),
        lengths=torch.Size(lengths),
    )


def _create_sharded_read_items(
    fqn: str,
    checkpoint_md: TensorStorageMetadata,
    local_shards: List[Shard],
) -> List[ReadItem]:
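    """Create one ReadItem per overlapping (local shard, checkpoint chunk) pair.

    This is what makes resharding work: each local shard loads the
    intersecting region of every checkpoint chunk it overlaps with.
    """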
    read_items = []
    # Naive quadratic pairing of local shards against checkpoint chunks;
    # this can be optimized later if it becomes a bottleneck.
    for idx, shard in enumerate(local_shards):
        for storage_idx, storage_md in enumerate(checkpoint_md.chunks):
            shard_md_from_storage = ShardMetadata(
                shard_sizes=list(storage_md.sizes),
                shard_offsets=list(storage_md.offsets),
            )

            # Skip checkpoint chunks that do not intersect this local shard.
            if not _check_shard_metadata_pair_overlap(
                shard.metadata, shard_md_from_storage
            ):
                continue

            # Per dimension: where the overlap starts in the saved chunk,
            # where it starts in the local shard, and how long it is.
            storage_offsets = []
            dest_offsets = []
            lengths = []
            for (
                dim,
                offset_for_saved_tensor,
                offset_for_current_tensor,
                length,
            ) in _shards_get_overlap_region_wrt_saved_tensor(
                saved_shard=shard_md_from_storage, current_shard=shard.metadata
            ):
                storage_offsets.append(offset_for_saved_tensor)
                dest_offsets.append(offset_for_current_tensor)
                lengths.append(length)

            read_items.append(
                _create_read_item_for_tensor(
                    dest_index=MetadataIndex(
                        fqn, shard.metadata.shard_offsets, idx
                    ),
                    dest_offsets=dest_offsets,
                    storage_index=MetadataIndex(
                        fqn, storage_md.offsets, storage_idx
                    ),
                    storage_offsets=storage_offsets,
                    lengths=lengths,
                )
            )
    return read_items


def _create_default_metadata_only_plan(state_dict: STATE_DICT_TYPE) -> SavePlan:
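    """Create a SavePlan describing every shard of every state_dict entry.

    Used to build checkpoint metadata without writing any tensor data.
    """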
    requests = []
    for fqn, obj in state_dict.items():
        if isinstance(obj, ShardedTensor):
            # Enumerate every shard in the global metadata, not just the
            # local ones, since this plan only describes the checkpoint.
            for shard_md in obj.metadata().shards_metadata:
                requests.append(
                    _create_write_item_for_shard(fqn, obj, shard_md)
                )
        elif isinstance(obj, torch.Tensor):
            requests.append(_create_write_item_for_tensor(fqn, obj))
        else:
            requests.append(_create_write_item_for_bytesio(fqn, obj))
    return SavePlan(requests)


def _create_write_items(fqn: str, obj: Any) -> List[WriteItem]:
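    """Create the WriteItems needed to save ``obj`` under ``fqn``.

    A ShardedTensor contributes one item per local shard; a plain tensor or
    opaque object contributes a single item.
    """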
    if isinstance(obj, ShardedTensor):
        return [
            _create_write_item_for_shard(fqn, obj, shard.metadata)
            for shard in obj.local_shards()
        ]
    elif isinstance(obj, torch.Tensor):
        return [_create_write_item_for_tensor(fqn, obj)]
    else:
        return [_create_write_item_for_bytesio(fqn, obj)]


def _create_read_items(fqn: str, md: STORAGE_TYPES, obj: Any) -> List[ReadItem]:
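    """Create the ReadItems needed to load ``obj`` from checkpoint metadata ``md``."""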
    if isinstance(md, BytesStorageMetadata):
        # Offsets and length are placeholders; byte objects are read back
        # in full.
        return [
            _create_read_item_for_byteio(
                dest_index=MetadataIndex(fqn),
                dest_offset=0,
                storage_index=MetadataIndex(fqn),
                storage_offset=0,
                length=0,
            )
        ]
    elif isinstance(obj, ShardedTensor):
        local_shards = obj.local_shards()
    elif isinstance(obj, torch.Tensor):
        # Treat a plain tensor as a single shard spanning the whole tensor.
        local_shards = [_create_shard_from_tensor(obj)]
    else:
        raise ValueError(
            f"Invalid state_dict entry for {fqn}: checkpoint metadata is "
            f"{type(md)}, which requires a tensor-like object, but found "
            f"{type(obj)}"
        )

    return _create_sharded_read_items(fqn, md, local_shards)