[DCP][BE] Apply ufmt to DCP and turn on lintrunner for DCP (#115302)
No logic changes; just typing and ufmt formatting.
Differential Revision: [D51914982](https://our.internmc.facebook.com/intern/diff/D51914982/)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/115302
Approved by: https://github.com/XilunWu, https://github.com/wz337, https://github.com/LucasLLC
ghstack dependencies: #115523
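Note: with the DCP paths removed from the exclude lists in `.lintrunner.toml`, these files are now covered by the normal lint flow. As a rough local check (assuming a standard lintrunner setup; exact flags may vary), something like `lintrunner --take UFMT torch/distributed/checkpoint/filesystem.py` reports formatting issues, and adding `-a` applies the fixes.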
diff --git a/.lintrunner.toml b/.lintrunner.toml
index ae54cfc..4a9ff17 100644
--- a/.lintrunner.toml
+++ b/.lintrunner.toml
@@ -869,7 +869,6 @@
'test/bottleneck_test/**', # excluded by test/run_test.py
'test/distributed/argparse_util_test.py',
'test/distributed/bin/test_script.py',
- 'test/distributed/checkpoint/e2e/test_pipeline.py',
'test/distributed/elastic/agent/server/test/local_elastic_agent_test.py',
'test/distributed/elastic/multiprocessing/bin/test_script.py',
'test/distributed/elastic/multiprocessing/bin/zombie_test.py',
@@ -1094,19 +1093,6 @@
'test/distributed/algorithms/test_join.py',
'test/distributed/argparse_util_test.py',
'test/distributed/bin/test_script.py',
- 'test/distributed/checkpoint/test_2d_fsdp_dt_checkpoint.py',
- 'test/distributed/checkpoint/test_checkpoint.py',
- 'test/distributed/checkpoint/test_dedup_tensors.py',
- 'test/distributed/checkpoint/test_dtensor_checkpoint.py',
- 'test/distributed/checkpoint/test_file_system_checkpoint.py',
- 'test/distributed/checkpoint/test_file_system_checkpoint_cpu.py',
- 'test/distributed/checkpoint/test_fsdp_model_state.py',
- 'test/distributed/checkpoint/test_fsdp_optim_state.py',
- 'test/distributed/checkpoint/test_fsspec.py',
- 'test/distributed/checkpoint/test_nested_dict.py',
- 'test/distributed/checkpoint/test_planner.py',
- 'test/distributed/checkpoint/test_traverse.py',
- 'test/distributed/checkpoint/test_utils.py',
'test/distributed/elastic/agent/server/test/__init__.py',
'test/distributed/elastic/agent/server/test/api_test.py',
'test/distributed/elastic/agent/server/test/local_elastic_agent_test.py',
@@ -2010,25 +1996,6 @@
'torch/distributed/autograd/__init__.py',
'torch/distributed/benchmarks/benchmark_ddp_rpc.py',
'torch/distributed/c10d_logger.py',
- 'torch/distributed/checkpoint/__init__.py',
- 'torch/distributed/checkpoint/_dedup_tensors.py',
- 'torch/distributed/checkpoint/_fsspec_filesystem.py',
- 'torch/distributed/checkpoint/_nested_dict.py',
- 'torch/distributed/checkpoint/_sharded_tensor_utils.py',
- 'torch/distributed/checkpoint/_traverse.py',
- 'torch/distributed/checkpoint/api.py',
- 'torch/distributed/checkpoint/default_planner.py',
- 'torch/distributed/checkpoint/examples/fsdp_checkpoint_example.py',
- 'torch/distributed/checkpoint/filesystem.py',
- 'torch/distributed/checkpoint/metadata.py',
- 'torch/distributed/checkpoint/optimizer.py',
- 'torch/distributed/checkpoint/planner.py',
- 'torch/distributed/checkpoint/planner_helpers.py',
- 'torch/distributed/checkpoint/resharding.py',
- 'torch/distributed/checkpoint/state_dict_loader.py',
- 'torch/distributed/checkpoint/state_dict_saver.py',
- 'torch/distributed/checkpoint/storage.py',
- 'torch/distributed/checkpoint/utils.py',
'torch/distributed/collective_utils.py',
'torch/distributed/constants.py',
'torch/distributed/distributed_c10d.py',
@@ -2442,7 +2409,6 @@
'torch/testing/_internal/distributed/_shard/test_common.py',
'torch/testing/_internal/distributed/_tensor/__init__.py',
'torch/testing/_internal/distributed/_tensor/common_dtensor.py',
- 'torch/testing/_internal/distributed/checkpoint_utils.py',
'torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py',
'torch/testing/_internal/distributed/distributed_test.py',
'torch/testing/_internal/distributed/distributed_utils.py',
diff --git a/test/distributed/checkpoint/e2e/test_pipeline.py b/test/distributed/checkpoint/e2e/test_pipeline.py
index cbf60b9..e197058 100644
--- a/test/distributed/checkpoint/e2e/test_pipeline.py
+++ b/test/distributed/checkpoint/e2e/test_pipeline.py
@@ -10,7 +10,7 @@
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
from torch.testing._internal.common_fsdp import FSDPTest
-from torch.testing._internal.common_utils import TEST_WITH_DEV_DBG_ASAN
+from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
@@ -102,3 +102,7 @@
self.assertTrue(os.path.exists(pipeline_dir))
self.save_with_pipeline(pipeline_dir)
self.load_with_fsdp(pipeline_dir)
+
+
+if __name__ == "__main__":
+ run_tests()
diff --git a/test/distributed/checkpoint/test_checkpoint.py b/test/distributed/checkpoint/test_checkpoint.py
index a2c9a29a..937853e 100644
--- a/test/distributed/checkpoint/test_checkpoint.py
+++ b/test/distributed/checkpoint/test_checkpoint.py
@@ -1,29 +1,28 @@
# Owner(s): ["oncall: distributed"]
import sys
-from typing import Optional, List, cast
-from torch.distributed.checkpoint.storage import WriteResult
-
-from torch.distributed.checkpoint import (
- StorageReader,
- StorageWriter,
- CheckpointException,
- load_state_dict,
- save_state_dict,
-)
+from typing import cast, List, Optional
import torch
import torch.distributed as dist
-import torch.nn
import torch.futures
-from torch.futures import Future
+import torch.nn
from torch.distributed._shard import sharded_tensor
-from torch.distributed.checkpoint.default_planner import (
- _create_default_local_metadata,
+from torch.distributed._shard.sharded_tensor import ShardedTensor, state_dict_hook
+from torch.distributed._shard.sharding_spec import ChunkShardingSpec
+
+from torch.distributed.checkpoint import (
+ CheckpointException,
+ load_state_dict,
+ save_state_dict,
+ StorageReader,
+ StorageWriter,
)
+from torch.distributed.checkpoint.default_planner import _create_default_local_metadata
+
from torch.distributed.checkpoint.metadata import (
BytesStorageMetadata,
Metadata,
@@ -31,31 +30,21 @@
)
from torch.distributed.checkpoint.planner import (
- SavePlan,
- SavePlanner,
LoadPlan,
LoadPlanner,
+ SavePlan,
+ SavePlanner,
)
+from torch.distributed.checkpoint.storage import WriteResult
+from torch.futures import Future
+from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
-from torch.distributed._shard.sharded_tensor import (
- state_dict_hook,
- ShardedTensor,
-)
-from torch.distributed._shard.sharding_spec import ChunkShardingSpec
-from torch.testing._internal.common_distributed import (
- requires_nccl,
- skip_if_lt_x_gpu,
-)
+from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
from torch.testing._internal.distributed._shard.sharded_tensor import (
ShardedTensorTestBase,
with_comms,
)
-from torch.testing._internal.common_utils import (
- TEST_WITH_DEV_DBG_ASAN,
- run_tests,
-)
-
if TEST_WITH_DEV_DBG_ASAN:
print(
"Skip dev-asan as torch + multiprocessing spawn have known issues",
@@ -175,9 +164,7 @@
ranks = self._get_ranks(name)
fut = Future()
if ranks is not None and self.rank in ranks:
- fut.set_exception(
- ValueError(f"async rank fail {self.rank} for {name}")
- )
+ fut.set_exception(ValueError(f"async rank fail {self.rank} for {name}"))
else:
fut.set_result(result)
return fut
@@ -204,9 +191,7 @@
self._fail_rank("fail_write_data")
return self._fail_rank_async("fail_write_data_async", [])
- def finish(
- self, metadata: Metadata, results: List[List[WriteResult]]
- ) -> None:
+ def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
self._fail_rank("fail_finish")
@@ -239,9 +224,7 @@
def get_spec(self):
return ChunkShardingSpec(
dim=0,
- placements=[
- f"rank:{r}/cuda:{r}" for r in range(dist.get_world_size())
- ],
+ placements=[f"rank:{r}/cuda:{r}" for r in range(dist.get_world_size())],
)
@with_comms(init_rpc=False)
diff --git a/test/distributed/checkpoint/test_dedup_tensors.py b/test/distributed/checkpoint/test_dedup_tensors.py
index 6f2b81c..37525f6 100644
--- a/test/distributed/checkpoint/test_dedup_tensors.py
+++ b/test/distributed/checkpoint/test_dedup_tensors.py
@@ -1,12 +1,11 @@
# Owner(s): ["oncall: distributed"]
import dataclasses
+
import torch
from torch.distributed.checkpoint._dedup_tensors import dedup_tensors
from torch.distributed.checkpoint.planner import SavePlan, WriteItemType
-from torch.distributed.checkpoint.planner_helpers import (
- _create_write_item_for_tensor,
-)
+from torch.distributed.checkpoint.planner_helpers import _create_write_item_for_tensor
from torch.testing._internal.common_utils import run_tests, TestCase
@@ -33,9 +32,7 @@
self.assertEqual(2, len(dedup_plans[0].items))
self.assertEqual(1, len(dedup_plans[1].items))
- self.assertIn(
- "tensor_0", (item.index.fqn for item in dedup_plans[0].items)
- )
+ self.assertIn("tensor_0", (item.index.fqn for item in dedup_plans[0].items))
self.assertIn("r0", (item.index.fqn for item in dedup_plans[0].items))
self.assertIn("r1", (item.index.fqn for item in dedup_plans[1].items))
diff --git a/test/distributed/checkpoint/test_dtensor_checkpoint.py b/test/distributed/checkpoint/test_dtensor_checkpoint.py
index 5f664bd..8c4a1ff 100644
--- a/test/distributed/checkpoint/test_dtensor_checkpoint.py
+++ b/test/distributed/checkpoint/test_dtensor_checkpoint.py
@@ -6,19 +6,19 @@
import torch.distributed.checkpoint as dist_cp
from torch.distributed._tensor import (
DeviceMesh,
+ distribute_tensor,
DTensor,
Replicate,
Shard,
- distribute_tensor,
zeros,
)
-from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
+from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
skip_if_lt_x_gpu,
with_comms,
)
-from torch.testing._internal.common_utils import run_tests
+from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
SUBMESH_TENSOR_SIZE = 6
@@ -81,9 +81,7 @@
device_type=self.device_type,
mesh=range(dist.get_world_size()),
)
- sharded_dt = distribute_tensor(
- tensor_to_shard, mesh, placements=[Shard(0)]
- )
+ sharded_dt = distribute_tensor(tensor_to_shard, mesh, placements=[Shard(0)])
replicated_dt = distribute_tensor(
tensor_to_replicate, mesh, placements=[Replicate()]
)
@@ -179,9 +177,7 @@
storage_writer=dist_cp.FileSystemWriter(path=CHECKPOINT_DIR),
planner=dist_cp.DefaultSavePlanner(),
)
- model, _, _ = self.create_dtensor_model(
- local_tensor * 10, local_tensor_2 * 10
- )
+ model, _, _ = self.create_dtensor_model(local_tensor * 10, local_tensor_2 * 10)
state_dict = model.state_dict()
"""
@@ -247,9 +243,7 @@
if k == "submesh_sdt":
if self.rank % 2 == 0:
shard_size = int(SUBMESH_TENSOR_SIZE / v.device_mesh.size())
- self.assertEqual(
- v.to_local().size(), torch.Size([shard_size])
- )
+ self.assertEqual(v.to_local().size(), torch.Size([shard_size]))
self.assertEqual(v.to_local(), torch.zeros([shard_size]))
else:
self.assertEqual(v.to_local().size(), torch.Size([0]))
@@ -258,9 +252,7 @@
if k == "submesh_rdt":
if self.rank % 2 == 0:
shard_size = SUBMESH_TENSOR_SIZE
- self.assertEqual(
- v.to_local().size(), torch.Size([shard_size])
- )
+ self.assertEqual(v.to_local().size(), torch.Size([shard_size]))
self.assertEqual(v.to_local(), torch.zeros([shard_size]))
else:
self.assertEqual(v.to_local().size(), torch.Size([0]))
diff --git a/test/distributed/checkpoint/test_file_system_checkpoint.py b/test/distributed/checkpoint/test_file_system_checkpoint.py
index 3d92e79..d33b60a 100644
--- a/test/distributed/checkpoint/test_file_system_checkpoint.py
+++ b/test/distributed/checkpoint/test_file_system_checkpoint.py
@@ -1,42 +1,21 @@
# Owner(s): ["oncall: distributed"]
import os
-import sys
import shutil
+import sys
import tempfile
from typing import Dict
import torch
import torch.distributed as dist
from torch.distributed._shard import sharded_tensor
-from torch.distributed._shard.sharded_tensor import (
- ShardedTensor,
- state_dict_hook,
-)
+from torch.distributed._shard.sharded_tensor import ShardedTensor, state_dict_hook
from torch.distributed._shard.sharding_spec import (
ChunkShardingSpec,
EnumerableShardingSpec,
ShardingSpec,
ShardMetadata,
)
-from torch.testing._internal.common_distributed import (
- requires_nccl,
- skip_if_lt_x_gpu,
-)
-from torch.testing._internal.common_utils import TestCase
-from torch.testing._internal.distributed._shard.sharded_tensor import (
- ShardedTensorTestBase,
- with_comms,
-)
-from torch.testing._internal.distributed._shard.sharded_tensor._test_st_common import (
- MyShardedModel1,
-)
-
-
-from torch.testing._internal.common_utils import (
- TEST_WITH_DEV_DBG_ASAN,
- run_tests,
-)
from torch.distributed.checkpoint import (
FileSystemReader,
@@ -44,6 +23,20 @@
load_state_dict,
save_state_dict,
)
+from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
+
+from torch.testing._internal.common_utils import (
+ run_tests,
+ TEST_WITH_DEV_DBG_ASAN,
+ TestCase,
+)
+from torch.testing._internal.distributed._shard.sharded_tensor import (
+ ShardedTensorTestBase,
+ with_comms,
+)
+from torch.testing._internal.distributed._shard.sharded_tensor._test_st_common import (
+ MyShardedModel1,
+)
if TEST_WITH_DEV_DBG_ASAN:
@@ -122,9 +115,7 @@
state_dict_to_load_to = MyTestModule().state_dict()
with self.assertRaises(AssertionError):
- assert_state_dict_equal(
- self, state_dict_to_load_to, state_dict_to_save
- )
+ assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
# Load from file without any resharding
fs_reader = FileSystemReader(path=path)
@@ -134,9 +125,7 @@
no_dist=True,
)
- assert_state_dict_equal(
- self, state_dict_to_load_to, state_dict_to_save
- )
+ assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
with tempfile.TemporaryDirectory() as path:
state_dict_to_save = MyTestModule().state_dict()
@@ -151,9 +140,7 @@
state_dict_to_load_to = MyTestModule().state_dict()
with self.assertRaises(AssertionError):
- assert_state_dict_equal(
- self, state_dict_to_load_to, state_dict_to_save
- )
+ assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
# Load from file without any resharding
fs_reader = FileSystemReader(path=path)
@@ -163,9 +150,7 @@
no_dist=True,
)
- assert_state_dict_equal(
- self, state_dict_to_load_to, state_dict_to_save
- )
+ assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
class TestDistributedStateDictSaveLoadWithSharedTensor(ShardedTensorTestBase):
@@ -212,15 +197,11 @@
dist.barrier()
with self.assertRaises(AssertionError):
- assert_state_dict_equal(
- self, state_dict_to_load_to, state_dict_to_save
- )
+ assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
# Test load.
fs_reader = FileSystemReader(path=path)
- load_state_dict(
- state_dict=state_dict_to_load_to, storage_reader=fs_reader
- )
+ load_state_dict(state_dict=state_dict_to_load_to, storage_reader=fs_reader)
assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
dist.barrier()
@@ -238,9 +219,7 @@
def load_tensor(self, tensor: ShardedTensor) -> torch.Tensor:
res = (
- torch.zeros(tensor.shape, device="cuda:0")
- if dist.get_rank() == 0
- else None
+ torch.zeros(tensor.shape, device="cuda:0") if dist.get_rank() == 0 else None
)
tensor.gather(out=res)
return res
@@ -335,9 +314,7 @@
state_dict_to_save = model_to_save.state_dict()
fs_writer = FileSystemWriter(path=path)
- save_state_dict(
- state_dict=state_dict_to_save, storage_writer=fs_writer
- )
+ save_state_dict(state_dict=state_dict_to_save, storage_writer=fs_writer)
dist.barrier()
@@ -404,9 +381,7 @@
fs_reader = FileSystemReader(path=path)
- load_state_dict(
- state_dict=state_dict_to_load_to, storage_reader=fs_reader
- )
+ load_state_dict(state_dict=state_dict_to_load_to, storage_reader=fs_reader)
# We can't use torch.allclose since each ST has a different sharding spec
store_tensor = self.load_tensor(model_to_save.sharded_tensor)
@@ -516,9 +491,7 @@
f"save-spec {save_spec} load-spec {load_spec}",
)
self.assertTrue(
- torch.allclose(
- save_dict["replicated"], load_dict_replicated
- ),
+ torch.allclose(save_dict["replicated"], load_dict_replicated),
f"save-spec {save_spec} load-spec {load_spec}",
)
diff --git a/test/distributed/checkpoint/test_fsdp_model_state.py b/test/distributed/checkpoint/test_fsdp_model_state.py
index 6cacf7c..9ccf36c 100644
--- a/test/distributed/checkpoint/test_fsdp_model_state.py
+++ b/test/distributed/checkpoint/test_fsdp_model_state.py
@@ -1,23 +1,23 @@
# Owner(s): ["oncall: distributed"]
import torch
+import torch.distributed as dist
+import torch.distributed.checkpoint as dist_cp
+
+from torch.distributed.checkpoint.default_planner import (
+ DefaultLoadPlanner,
+ DefaultSavePlanner,
+)
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
-import torch.distributed.checkpoint as dist_cp
-import torch.distributed as dist
-
-from torch.distributed.checkpoint.default_planner import (
- DefaultSavePlanner,
- DefaultLoadPlanner,
-)
+from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
+from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
DTensorTestBase,
with_comms,
)
-from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
-from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
diff --git a/test/distributed/checkpoint/test_fsspec.py b/test/distributed/checkpoint/test_fsspec.py
index 14a4de9..b5d4195 100644
--- a/test/distributed/checkpoint/test_fsspec.py
+++ b/test/distributed/checkpoint/test_fsspec.py
@@ -9,23 +9,12 @@
import torch.distributed as dist
import torch.distributed.checkpoint as dcp
import torch.nn as nn
-from torch.distributed.checkpoint._fsspec_filesystem import (
- FsspecReader,
- FsspecWriter,
-)
-from torch.distributed.checkpoint.optimizer import (
- load_sharded_optimizer_state_dict,
-)
+from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader, FsspecWriter
+from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
-from torch.testing._internal.common_distributed import (
- requires_nccl,
- skip_if_lt_x_gpu,
-)
-from torch.testing._internal.common_utils import (
- run_tests,
- TestCase,
-)
+from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
+from torch.testing._internal.common_utils import run_tests, TestCase
from torch.testing._internal.distributed._shard.sharded_tensor import (
ShardedTensorTestBase,
with_comms,
@@ -182,9 +171,7 @@
return list(iter(opt.state.values()))[idx]
# Adam lazily creates its state
- self.assertEqual(
- opt_at(optim, 0)["exp_avg"], opt_at(optim_2, 0)["exp_avg"]
- )
+ self.assertEqual(opt_at(optim, 0)["exp_avg"], opt_at(optim_2, 0)["exp_avg"])
self.assertEqual(
opt_at(optim, 0)["exp_avg_sq"], opt_at(optim_2, 0)["exp_avg_sq"]
)
diff --git a/test/distributed/checkpoint/test_nested_dict.py b/test/distributed/checkpoint/test_nested_dict.py
index 115982e..12df302 100644
--- a/test/distributed/checkpoint/test_nested_dict.py
+++ b/test/distributed/checkpoint/test_nested_dict.py
@@ -1,11 +1,11 @@
# Owner(s): ["oncall: distributed"]
import torch
-from torch.testing._internal.common_utils import run_tests, TestCase
from torch.distributed.checkpoint._nested_dict import (
flatten_state_dict,
unflatten_state_dict,
)
+from torch.testing._internal.common_utils import run_tests, TestCase
class TestFlattening(TestCase):
diff --git a/test/distributed/checkpoint/test_planner.py b/test/distributed/checkpoint/test_planner.py
index a8563c7..53c129d 100644
--- a/test/distributed/checkpoint/test_planner.py
+++ b/test/distributed/checkpoint/test_planner.py
@@ -3,43 +3,45 @@
import sys
import torch
-from torch.distributed.checkpoint.planner import LoadItemType, WriteItemType
from torch.distributed._shard.sharded_tensor import (
Shard,
- ShardMetadata,
ShardedTensor,
ShardedTensorMetadata,
+ ShardMetadata,
)
from torch.distributed._shard.sharded_tensor.metadata import TensorProperties
+from torch.distributed.checkpoint._dedup_tensors import dedup_tensors
-from torch.testing._internal.common_utils import (
- TestCase,
- TEST_WITH_DEV_DBG_ASAN,
- run_tests,
+from torch.distributed.checkpoint.default_planner import (
+ _create_default_local_metadata,
+ create_default_global_save_plan,
+ create_default_local_load_plan,
+ create_default_local_save_plan,
)
from torch.distributed.checkpoint.metadata import (
BytesStorageMetadata,
+ ChunkStorageMetadata,
MetadataIndex,
TensorStorageMetadata,
- ChunkStorageMetadata,
+)
+from torch.distributed.checkpoint.planner import LoadItemType, WriteItemType
+
+from torch.distributed.checkpoint.planner_helpers import (
+ create_read_items_for_chunk_list,
+)
+
+from torch.testing._internal.common_utils import (
+ run_tests,
+ TEST_WITH_DEV_DBG_ASAN,
+ TestCase,
)
from torch.testing._internal.distributed.distributed_utils import (
+ with_dist,
with_fake_comms,
- with_dist
)
-from torch.distributed.checkpoint.default_planner import (
- create_default_global_save_plan,
- create_default_local_save_plan,
- create_default_local_load_plan,
- _create_default_local_metadata
-)
-
-from torch.distributed.checkpoint.planner_helpers import create_read_items_for_chunk_list
-from torch.distributed.checkpoint._dedup_tensors import dedup_tensors
-
if TEST_WITH_DEV_DBG_ASAN:
print(
@@ -48,30 +50,34 @@
)
sys.exit(0)
+
def create_sharded_tensor(rank, world_size, shards_per_rank, shard_size=8):
shards_metadata = []
local_shards = []
for idx in range(0, world_size * shards_per_rank):
shard_rank = idx // shards_per_rank
- shard_md = ShardMetadata(shard_offsets=[idx * shard_size], shard_sizes=[shard_size], placement=f"rank:{shard_rank}/cpu")
+ shard_md = ShardMetadata(
+ shard_offsets=[idx * shard_size],
+ shard_sizes=[shard_size],
+ placement=f"rank:{shard_rank}/cpu",
+ )
shards_metadata.append(shard_md)
if shard_rank == rank:
shard = Shard.from_tensor_and_offsets(
torch.rand(*shard_md.shard_sizes),
shard_offsets=shard_md.shard_offsets,
- rank=rank
+ rank=rank,
)
local_shards.append(shard)
sharded_tensor_md = ShardedTensorMetadata(
shards_metadata=shards_metadata,
size=torch.Size([shard_size * len(shards_metadata)]),
- tensor_properties=TensorProperties.create_from_tensor(torch.zeros(1))
+ tensor_properties=TensorProperties.create_from_tensor(torch.zeros(1)),
)
return ShardedTensor._init_from_local_shards_and_global_metadata(
- local_shards=local_shards,
- sharded_tensor_metadata=sharded_tensor_md
+ local_shards=local_shards, sharded_tensor_metadata=sharded_tensor_md
)
@@ -81,18 +87,17 @@
tensor = torch.rand(10)
val = [1, 2, 3]
st = create_sharded_tensor(rank=1, world_size=4, shards_per_rank=1)
- state_dict = {
- "tensor": tensor,
- "value": val,
- "st": st
- }
+ state_dict = {"tensor": tensor, "value": val, "st": st}
plan = create_default_local_save_plan(state_dict, False)
self.assertEqual(2, len(plan.items))
wi = plan.items[0]
self.assertEqual(wi.index, MetadataIndex("tensor", [0]))
self.assertEqual(wi.type, WriteItemType.TENSOR)
self.assertEqual(wi.tensor_data.size, tensor.size())
- self.assertEqual(wi.tensor_data.properties, TensorProperties.create_from_tensor(torch.zeros(1)))
+ self.assertEqual(
+ wi.tensor_data.properties,
+ TensorProperties.create_from_tensor(torch.zeros(1)),
+ )
self.assertEqual(wi.tensor_data.chunk.offsets, torch.Size([0]))
self.assertEqual(wi.tensor_data.chunk.sizes, torch.Size([10]))
@@ -100,7 +105,10 @@
self.assertEqual(st_wi.index, MetadataIndex("st", [8]))
self.assertEqual(st_wi.type, WriteItemType.SHARD)
self.assertEqual(st_wi.tensor_data.size, st.size())
- self.assertEqual(st_wi.tensor_data.properties, TensorProperties.create_from_tensor(torch.zeros(1)))
+ self.assertEqual(
+ st_wi.tensor_data.properties,
+ TensorProperties.create_from_tensor(torch.zeros(1)),
+ )
self.assertEqual(st_wi.tensor_data.chunk.offsets, torch.Size([8]))
self.assertEqual(st_wi.tensor_data.chunk.sizes, torch.Size([8]))
@@ -111,7 +119,10 @@
tensor_wi = next(wi for wi in plan.items if wi.type == WriteItemType.TENSOR)
self.assertEqual(tensor_wi.index, MetadataIndex("tensor", [0]))
self.assertEqual(tensor_wi.tensor_data.size, tensor.size())
- self.assertEqual(tensor_wi.tensor_data.properties, TensorProperties.create_from_tensor(tensor))
+ self.assertEqual(
+ tensor_wi.tensor_data.properties,
+ TensorProperties.create_from_tensor(tensor),
+ )
self.assertEqual(tensor_wi.tensor_data.chunk.offsets, torch.Size([0]))
self.assertEqual(tensor_wi.tensor_data.chunk.sizes, torch.Size([10]))
@@ -125,11 +136,7 @@
tensor = torch.rand(10)
val = [1, 2, 3]
st = create_sharded_tensor(rank=rank, world_size=4, shards_per_rank=1)
- state_dict = {
- "tensor": tensor,
- "value": val,
- "st": st
- }
+ state_dict = {"tensor": tensor, "value": val, "st": st}
return create_default_local_save_plan(state_dict, rank == 0)
all_plans = [create_data(0), create_data(1), create_data(2), create_data(3)]
@@ -150,11 +157,15 @@
else:
self.assertTrue(isinstance(item_md, TensorStorageMetadata))
self.assertEqual(item_md.size, old_item.tensor_data.size)
- self.assertEqual(item_md.properties, old_item.tensor_data.properties)
+ self.assertEqual(
+ item_md.properties, old_item.tensor_data.properties
+ )
self.assertIsNotNone(new_item.index.index)
# Make sure the hint is correct
- self.assertEqual(item_md.chunks[new_item.index.index], old_item.tensor_data.chunk)
+ self.assertEqual(
+ item_md.chunks[new_item.index.index], old_item.tensor_data.chunk
+ )
def test_local_load_plan(self):
def create_state_dict(rank):
@@ -162,11 +173,7 @@
tensor = torch.rand(10)
val = [1, 2, 3]
st = create_sharded_tensor(rank=rank, world_size=4, shards_per_rank=1)
- return {
- "tensor": tensor,
- "value": val,
- "st": st
- }
+ return {"tensor": tensor, "value": val, "st": st}
state_dict = create_state_dict(1)
metadata = _create_default_local_metadata(state_dict)
@@ -175,7 +182,9 @@
# This will create 3 entries
self.assertEqual(3, len(load_plan.items))
st_item = next(ri for ri in load_plan.items if ri.dest_index.fqn == "st")
- tensor_item = next(ri for ri in load_plan.items if ri.dest_index.fqn == "tensor")
+ tensor_item = next(
+ ri for ri in load_plan.items if ri.dest_index.fqn == "tensor"
+ )
bytes_item = next(ri for ri in load_plan.items if ri.dest_index.fqn == "value")
self.assertEqual(st_item.type, LoadItemType.TENSOR)
@@ -208,7 +217,6 @@
)
}
-
# Rank 1 has a 16 bytes shard from [16, 32[
world8_state_dict = create_state_dict(rank=1, world_size=8)
world8_metadata = _create_default_local_metadata(world8_state_dict)
@@ -221,8 +229,12 @@
# Each 4-world shard has 32 elements, so it needs to load 2 shards
load_plan = create_default_local_load_plan(world4_state_dict, world8_metadata)
self.assertEqual(2, len(load_plan.items))
- low_ri = next(ri for ri in load_plan.items if ri.dest_offsets == torch.Size([0]))
- high_ri = next(ri for ri in load_plan.items if ri.dest_offsets == torch.Size([16]))
+ low_ri = next(
+ ri for ri in load_plan.items if ri.dest_offsets == torch.Size([0])
+ )
+ high_ri = next(
+ ri for ri in load_plan.items if ri.dest_offsets == torch.Size([16])
+ )
self.assertEqual(low_ri.storage_index, MetadataIndex("st", [32]))
self.assertEqual(low_ri.storage_offsets, torch.Size([0]))
@@ -258,6 +270,7 @@
shard_size=120 // world_size,
)
}
+
# rank 1 has a 30 bytes shard from [30, 60[
world4_state_dict = create_state_dict(rank=1, world_size=4)
world4_metadata = _create_default_local_metadata(world4_state_dict)
@@ -268,9 +281,13 @@
load_plan = create_default_local_load_plan(world3_state_dict, world4_metadata)
self.assertEqual(2, len(load_plan.items))
# this is [30, 60] to load [40, 60]
- low_ri = next(ri for ri in load_plan.items if ri.dest_offsets == torch.Size([0]))
+ low_ri = next(
+ ri for ri in load_plan.items if ri.dest_offsets == torch.Size([0])
+ )
# this is [60, 90] to load [60, 80]
- high_ri = next(ri for ri in load_plan.items if ri.dest_offsets == torch.Size([20]))
+ high_ri = next(
+ ri for ri in load_plan.items if ri.dest_offsets == torch.Size([20])
+ )
self.assertEqual(low_ri.storage_index, MetadataIndex("st", [30]))
self.assertEqual(low_ri.storage_offsets, torch.Size([10]))
@@ -284,27 +301,19 @@
self.assertEqual(high_ri.dest_offsets, torch.Size([20]))
self.assertEqual(high_ri.lengths, torch.Size([20]))
+
class TestPlannerHelpers(TestCase):
def test_create_read_item_from_chunks(self):
tensor_md = TensorStorageMetadata(
properties=TensorProperties.create_from_tensor(torch.empty([16])),
size=torch.Size([16]),
chunks=[
- ChunkStorageMetadata(
- offsets=torch.Size([0]),
- sizes=torch.Size([8])
- ),
- ChunkStorageMetadata(
- offsets=torch.Size([8]),
- sizes=torch.Size([8])
- )
- ]
+ ChunkStorageMetadata(offsets=torch.Size([0]), sizes=torch.Size([8])),
+ ChunkStorageMetadata(offsets=torch.Size([8]), sizes=torch.Size([8])),
+ ],
)
- chunk = ChunkStorageMetadata(
- offsets=torch.Size([4]),
- sizes=torch.Size([7])
- )
+ chunk = ChunkStorageMetadata(offsets=torch.Size([4]), sizes=torch.Size([7]))
read_items = create_read_items_for_chunk_list("foo", tensor_md, [chunk])
self.assertEqual(2, len(read_items))
@@ -316,7 +325,6 @@
self.assertEqual(torch.Size([4]), read_items[0].lengths)
-
self.assertEqual(MetadataIndex("foo", [4]), read_items[1].dest_index)
self.assertEqual(torch.Size([4]), read_items[1].dest_offsets)
@@ -325,5 +333,6 @@
self.assertEqual(torch.Size([3]), read_items[1].lengths)
+
if __name__ == "__main__":
run_tests()
diff --git a/test/distributed/checkpoint/test_traverse.py b/test/distributed/checkpoint/test_traverse.py
index 3a47311..4755967 100644
--- a/test/distributed/checkpoint/test_traverse.py
+++ b/test/distributed/checkpoint/test_traverse.py
@@ -1,6 +1,7 @@
# Owner(s): ["oncall: distributed"]
from collections import OrderedDict
+
import torch
import torch.distributed.checkpoint._traverse as _traverse
@@ -95,9 +96,7 @@
self.assertEqual(data[("key0", "key2")], torch.tensor([1]))
def test_traverse_doesnt_ignore_intermediate_collections(self) -> None:
- state_dict: STATE_DICT_TYPE = {
- "key0": [{"key1": {"key2": torch.tensor([1])}}]
- }
+ state_dict: STATE_DICT_TYPE = {"key0": [{"key1": {"key2": torch.tensor([1])}}]}
data = {}
diff --git a/test/distributed/checkpoint/test_utils.py b/test/distributed/checkpoint/test_utils.py
index e2b4aac..78d97f0 100644
--- a/test/distributed/checkpoint/test_utils.py
+++ b/test/distributed/checkpoint/test_utils.py
@@ -6,22 +6,20 @@
from torch.distributed._shard.sharded_tensor import (
Shard,
- ShardMetadata,
ShardedTensor,
ShardedTensorMetadata,
+ ShardMetadata,
)
from torch.distributed._shard.sharded_tensor.metadata import TensorProperties
+from torch.distributed.checkpoint.metadata import MetadataIndex
+from torch.distributed.checkpoint.utils import find_state_dict_object
from torch.testing._internal.common_utils import (
- TestCase,
- TEST_WITH_DEV_DBG_ASAN,
run_tests,
+ TEST_WITH_DEV_DBG_ASAN,
+ TestCase,
)
-from torch.distributed.checkpoint.utils import find_state_dict_object
-from torch.distributed.checkpoint.metadata import MetadataIndex
-from torch.testing._internal.distributed.distributed_utils import (
- with_fake_comms
-)
+from torch.testing._internal.distributed.distributed_utils import with_fake_comms
if TEST_WITH_DEV_DBG_ASAN:
print(
@@ -30,30 +28,32 @@
)
sys.exit(0)
+
def create_sharded_tensor(rank, world_size, shards_per_rank):
shards_metadata = []
local_shards = []
for idx in range(0, world_size * shards_per_rank):
shard_rank = idx // shards_per_rank
- shard_md = ShardMetadata(shard_offsets=[idx * 8], shard_sizes=[8], placement=f"rank:{shard_rank}/cpu")
+ shard_md = ShardMetadata(
+ shard_offsets=[idx * 8], shard_sizes=[8], placement=f"rank:{shard_rank}/cpu"
+ )
shards_metadata.append(shard_md)
if shard_rank == rank:
shard = Shard.from_tensor_and_offsets(
torch.rand(*shard_md.shard_sizes),
shard_offsets=shard_md.shard_offsets,
- rank=rank
+ rank=rank,
)
local_shards.append(shard)
sharded_tensor_md = ShardedTensorMetadata(
shards_metadata=shards_metadata,
size=torch.Size([8 * len(shards_metadata)]),
- tensor_properties=TensorProperties.create_from_tensor(torch.zeros(1))
+ tensor_properties=TensorProperties.create_from_tensor(torch.zeros(1)),
)
return ShardedTensor._init_from_local_shards_and_global_metadata(
- local_shards=local_shards,
- sharded_tensor_metadata=sharded_tensor_md
+ local_shards=local_shards, sharded_tensor_metadata=sharded_tensor_md
)
diff --git a/torch/distributed/checkpoint/__init__.py b/torch/distributed/checkpoint/__init__.py
index c68747a..b45d0c2 100644
--- a/torch/distributed/checkpoint/__init__.py
+++ b/torch/distributed/checkpoint/__init__.py
@@ -1,23 +1,15 @@
+from .api import CheckpointException
+from .checkpointer import Checkpointer
+from .default_planner import DefaultLoadPlanner, DefaultSavePlanner
+from .filesystem import FileSystemCheckpointer, FileSystemReader, FileSystemWriter
from .metadata import (
- TensorStorageMetadata,
BytesStorageMetadata,
ChunkStorageMetadata,
Metadata,
+ TensorStorageMetadata,
)
-from .state_dict_loader import load_state_dict, load
-from .state_dict_saver import save_state_dict, save
-from .storage import StorageReader, StorageWriter
-from .checkpointer import Checkpointer
-from .filesystem import FileSystemReader, FileSystemWriter, FileSystemCheckpointer
-from .api import CheckpointException
-
-from .planner import (
- SavePlanner,
- LoadPlanner,
- SavePlan,
- LoadPlan,
- ReadItem,
- WriteItem,
-)
-from .default_planner import DefaultSavePlanner, DefaultLoadPlanner
from .optimizer import load_sharded_optimizer_state_dict
+from .planner import LoadPlan, LoadPlanner, ReadItem, SavePlan, SavePlanner, WriteItem
+from .state_dict_loader import load, load_state_dict
+from .state_dict_saver import save, save_state_dict
+from .storage import StorageReader, StorageWriter
diff --git a/torch/distributed/checkpoint/_dedup_tensors.py b/torch/distributed/checkpoint/_dedup_tensors.py
index cf7ca01..a6be92c 100644
--- a/torch/distributed/checkpoint/_dedup_tensors.py
+++ b/torch/distributed/checkpoint/_dedup_tensors.py
@@ -23,8 +23,10 @@
logger.propagate = False
return logger
+
logger = init_logger()
+
# TODO add docstring for dedup_tensors
def dedup_tensors(all_plans: List[SavePlan]) -> List[SavePlan]:
all_plans = list(all_plans)
@@ -51,8 +53,6 @@
for write_item in all_plans[plan_idx].items
if write_item.index not in key_set
]
- all_plans[plan_idx] = dataclasses.replace(
- all_plans[plan_idx], items=new_items
- )
+ all_plans[plan_idx] = dataclasses.replace(all_plans[plan_idx], items=new_items)
return all_plans
diff --git a/torch/distributed/checkpoint/_fsspec_filesystem.py b/torch/distributed/checkpoint/_fsspec_filesystem.py
index 8703d1b..5c98845 100644
--- a/torch/distributed/checkpoint/_fsspec_filesystem.py
+++ b/torch/distributed/checkpoint/_fsspec_filesystem.py
@@ -9,20 +9,18 @@
import queue
import threading
from abc import ABC, abstractmethod
-
from dataclasses import dataclass
from typing import Callable, cast, Dict, List, Optional, Union
import fsspec
-import torch
from fsspec import AbstractFileSystem
from fsspec.core import url_to_fs
+
+import torch
from torch import Tensor
from torch._utils import _get_device_module
-
from torch.distributed._shard._utils import narrow_tensor_by_index
from torch.distributed.checkpoint.metadata import Metadata, MetadataIndex
-
from torch.distributed.checkpoint.planner import (
LoadItemType,
LoadPlan,
@@ -145,10 +143,7 @@
def _refill(self):
with self.device_module.stream(self.stream):
- while (
- not self._done
- and self.in_flight_data < self.inflight_threshhold
- ):
+ while not self._done and self.in_flight_data < self.inflight_threshhold:
_, obj = self.items[self.idx]
self.idx += 1
tensor = self.resolve_fun(obj).detach()
@@ -206,9 +201,7 @@
return size * torch._utils._element_size(dtype)
-def _split_by_size_and_type(
- bins: int, items: List[WriteItem]
-) -> List[List[WriteItem]]:
+def _split_by_size_and_type(bins: int, items: List[WriteItem]) -> List[List[WriteItem]]:
if bins == 1:
return [items]
@@ -276,16 +269,12 @@
planner.resolve_data,
)
- tensor_w = [
- wi for wi in write_items if wi.type != WriteItemType.BYTE_IO
- ]
+ tensor_w = [wi for wi in write_items if wi.type != WriteItemType.BYTE_IO]
for write_item in tensor_w:
loader.add(_item_size(write_item), write_item)
loader.start_loading()
- bytes_w = [
- wi for wi in write_items if wi.type == WriteItemType.BYTE_IO
- ]
+ bytes_w = [wi for wi in write_items if wi.type == WriteItemType.BYTE_IO]
write_results = []
with fs.transaction:
@@ -351,9 +340,7 @@
self.fs.makedirs(self.path, exist_ok=True)
return plan
- def prepare_global_plan(
- self, global_plan: List[SavePlan]
- ) -> List[SavePlan]:
+ def prepare_global_plan(self, global_plan: List[SavePlan]) -> List[SavePlan]:
new_plans = [
dataclasses.replace(plan, storage_data=_StoragePrefix(f"__{i}_"))
for i, plan in enumerate(global_plan)
@@ -376,9 +363,7 @@
file_queue: queue.Queue = queue.Queue()
if self.single_file_per_rank:
- for bucket in _split_by_size_and_type(
- self.thread_count, plan.items
- ):
+ for bucket in _split_by_size_and_type(self.thread_count, plan.items):
file_name = gen_file()
file_path = os.path.join(self.path, file_name)
file_queue.put((file_path, file_name, bucket))
@@ -427,9 +412,7 @@
fut.set_result(res)
return fut
- def finish(
- self, metadata: Metadata, results: List[List[WriteResult]]
- ) -> None:
+ def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
storage_md = dict()
for wr_list in results:
storage_md.update({wr.index: wr.storage_data for wr in wr_list})
@@ -495,16 +478,12 @@
with fsspec.open(metadata_path, "rb") as metadata_file:
return pickle.load(metadata_file)
- def set_up_storage_reader(
- self, metadata: Metadata, is_coordinator: bool
- ) -> None:
+ def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
self.storage_data = metadata.storage_data
assert self.storage_data is not None
def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
return plan
- def prepare_global_plan(
- self, global_plan: List[LoadPlan]
- ) -> List[LoadPlan]:
+ def prepare_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]:
return global_plan
diff --git a/torch/distributed/checkpoint/_nested_dict.py b/torch/distributed/checkpoint/_nested_dict.py
index 6d2e0bc..527a67e 100644
--- a/torch/distributed/checkpoint/_nested_dict.py
+++ b/torch/distributed/checkpoint/_nested_dict.py
@@ -1,16 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
from typing import Dict, Tuple
-from torch.distributed.checkpoint.metadata import (
- STATE_DICT_TYPE,
-)
+from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
-from ._traverse import (
- traverse_state_dict,
- set_element,
- OBJ_PATH,
- STATE_DICT_ITEM,
-)
+from ._traverse import OBJ_PATH, set_element, STATE_DICT_ITEM, traverse_state_dict
"""
TODO:
diff --git a/torch/distributed/checkpoint/_sharded_tensor_utils.py b/torch/distributed/checkpoint/_sharded_tensor_utils.py
index 4fff382..582dfc0 100644
--- a/torch/distributed/checkpoint/_sharded_tensor_utils.py
+++ b/torch/distributed/checkpoint/_sharded_tensor_utils.py
@@ -3,29 +3,12 @@
import copy
import torch.distributed as dist
+from torch.distributed._shard.sharded_tensor import Shard, ShardedTensor, ShardMetadata
+from torch.distributed._shard.sharded_tensor.metadata import ShardedTensorMetadata
+from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
from torch.distributed.remote_device import _remote_device
-from torch.distributed.checkpoint.metadata import (
- STATE_DICT_TYPE,
-)
-from torch.distributed._shard.sharded_tensor import (
- Shard,
- ShardMetadata,
- ShardedTensor,
-)
-
-from torch.distributed._shard.sharded_tensor.metadata import (
- ShardedTensorMetadata,
-)
-
-
-from ._traverse import (
- OBJ_PATH,
- traverse_state_dict,
- set_element,
- STATE_DICT_ITEM,
-)
-
+from ._traverse import OBJ_PATH, set_element, STATE_DICT_ITEM, traverse_state_dict
from .utils import _element_wise_add, _normalize_device_info
@@ -62,9 +45,7 @@
return
if len(inner_st.local_shards()) != 1:
- raise ValueError(
- "Cannot handle inner tensor with more than 1 shard"
- )
+ raise ValueError("Cannot handle inner tensor with more than 1 shard")
inner_shard = inner_st.local_shards()[0]
local_shards = [
diff --git a/torch/distributed/checkpoint/_traverse.py b/torch/distributed/checkpoint/_traverse.py
index 1367dc2..604b5e1 100644
--- a/torch/distributed/checkpoint/_traverse.py
+++ b/torch/distributed/checkpoint/_traverse.py
@@ -1,8 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
-import torch
-
from typing import (
Callable,
+ cast,
Collection,
List,
Mapping,
@@ -11,13 +10,12 @@
Tuple,
TypeVar,
Union,
- cast,
)
-from torch.distributed.checkpoint.metadata import (
- STATE_DICT_TYPE,
-)
+
+import torch
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
from torch.distributed._tensor import DTensor
+from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
PATH_ITEM = Union[str, int]
OBJ_PATH = Tuple[PATH_ITEM, ...]
@@ -47,6 +45,7 @@
By default, all collections with at least one ``torch.Tensor`` element are traversed.
Visitor takes a path argument that is a tuple of the keys used to reach it.
"""
+
# a value is terminal if it has no other containers values inside it
def _is_terminal(value: STATE_DICT_ITEM) -> bool:
values: Collection[STATE_DICT_ITEM]
diff --git a/torch/distributed/checkpoint/api.py b/torch/distributed/checkpoint/api.py
index 8ccd740..8286851 100644
--- a/torch/distributed/checkpoint/api.py
+++ b/torch/distributed/checkpoint/api.py
@@ -1,5 +1,5 @@
-from typing import Dict, Tuple, Any
import traceback as tb
+from typing import Any, Dict, Tuple
WRAPPED_EXCEPTION = Tuple[BaseException, tb.StackSummary]
@@ -15,9 +15,7 @@
return False
if len(obj) != 2:
return False
- return isinstance(obj[0], BaseException) and isinstance(
- obj[1], tb.StackSummary
- )
+ return isinstance(obj[0], BaseException) and isinstance(obj[1], tb.StackSummary)
class CheckpointException(BaseException):
diff --git a/torch/distributed/checkpoint/default_planner.py b/torch/distributed/checkpoint/default_planner.py
index 9bf74ab..775d5b8 100644
--- a/torch/distributed/checkpoint/default_planner.py
+++ b/torch/distributed/checkpoint/default_planner.py
@@ -6,50 +6,42 @@
import operator
from collections import ChainMap
from functools import reduce
-from typing import List, Tuple, Dict, Any, Union, cast
+from typing import Any, cast, Dict, List, Tuple, Union
import torch
-
from torch.distributed._shard._utils import narrow_tensor_by_index
from torch.distributed._tensor import DTensor
-
-
-from torch.distributed.checkpoint.planner import (
- SavePlanner,
- LoadPlanner,
- SavePlan,
- LoadPlan,
- ReadItem,
- WriteItem,
- WriteItemType,
-)
-
-from torch.distributed.checkpoint.metadata import (
- BytesStorageMetadata,
- ChunkStorageMetadata,
- TensorStorageMetadata,
- MetadataIndex,
- Metadata,
- STATE_DICT_TYPE,
- STORAGE_TYPES,
-)
-
-from torch.distributed.checkpoint.planner_helpers import (
- _create_read_items,
- _create_write_items,
- _create_default_metadata_only_plan,
-)
-
+from torch.distributed.checkpoint._dedup_tensors import dedup_tensors
from torch.distributed.checkpoint._nested_dict import (
FLATTEN_MAPPING,
flatten_state_dict,
)
-from torch.distributed.checkpoint._sharded_tensor_utils import (
- _flatten_sharded_tensors,
-)
-from torch.distributed.checkpoint._dedup_tensors import dedup_tensors
-from torch.distributed.checkpoint.utils import find_state_dict_object
+from torch.distributed.checkpoint._sharded_tensor_utils import _flatten_sharded_tensors
from torch.distributed.checkpoint._traverse import set_element
+from torch.distributed.checkpoint.metadata import (
+ BytesStorageMetadata,
+ ChunkStorageMetadata,
+ Metadata,
+ MetadataIndex,
+ STATE_DICT_TYPE,
+ STORAGE_TYPES,
+ TensorStorageMetadata,
+)
+from torch.distributed.checkpoint.planner import (
+ LoadPlan,
+ LoadPlanner,
+ ReadItem,
+ SavePlan,
+ SavePlanner,
+ WriteItem,
+ WriteItemType,
+)
+from torch.distributed.checkpoint.planner_helpers import (
+ _create_default_metadata_only_plan,
+ _create_read_items,
+ _create_write_items,
+)
+from torch.distributed.checkpoint.utils import find_state_dict_object
logger: logging.Logger = logging.getLogger(__name__)
@@ -79,9 +71,7 @@
self.dedup_replicated_tensors = dedup_replicated_tensors
self.mappings = {}
- def set_up_planner(
- self, state_dict: STATE_DICT_TYPE, is_coordinator: bool
- ) -> None:
+ def set_up_planner(self, state_dict: STATE_DICT_TYPE, is_coordinator: bool) -> None:
if self.flatten_state_dict:
state_dict, self.mappings = flatten_state_dict(state_dict)
if self.flatten_sharded_tensors:
@@ -90,9 +80,7 @@
self.is_coordinator = is_coordinator
def create_local_plan(self) -> SavePlan:
- plan = create_default_local_save_plan(
- self.state_dict, self.is_coordinator
- )
+ plan = create_default_local_save_plan(self.state_dict, self.is_coordinator)
if self.flatten_state_dict:
plan = dataclasses.replace(plan, planner_data=self.mappings)
self.plan = plan
@@ -114,9 +102,7 @@
# )
planner_data_dict = [p.planner_data for p in global_plan]
merged_mappings = dict(ChainMap(*planner_data_dict))
- metadata = dataclasses.replace(
- metadata, planner_data=merged_mappings
- )
+ metadata = dataclasses.replace(metadata, planner_data=merged_mappings)
if not _validate_global_plan(global_plan, metadata):
raise ValueError("Failed to validate global plan")
@@ -130,9 +116,7 @@
self.plan = new_plan
return new_plan
- def resolve_data(
- self, write_item: WriteItem
- ) -> Union[torch.Tensor, io.BytesIO]:
+ def resolve_data(self, write_item: WriteItem) -> Union[torch.Tensor, io.BytesIO]:
object = self.lookup_object(write_item.index)
return self.transform_object(write_item, object)
@@ -222,9 +206,7 @@
def transform_tensor(self, read_item: ReadItem, tensor: torch.Tensor):
"""Extension from the planner interface to make it easy to extend the default planner."""
- return narrow_tensor_by_index(
- tensor, read_item.dest_offsets, read_item.lengths
- )
+ return narrow_tensor_by_index(tensor, read_item.dest_offsets, read_item.lengths)
def create_default_local_load_plan(
@@ -353,9 +335,7 @@
return md
-def _check_box_overlap(
- box0: ChunkStorageMetadata, box1: ChunkStorageMetadata
-) -> bool:
+def _check_box_overlap(box0: ChunkStorageMetadata, box1: ChunkStorageMetadata) -> bool:
"""Check if two boxes overlap. Tuples are (offset, lengths)."""
# For each dim of each shard, check if one shard resides on the other
# end of second shard with respect to that dim. As an example for a 2D
@@ -385,9 +365,7 @@
return True
-def _validate_global_plan(
- global_plan: List[SavePlan], metadata: Metadata
-) -> bool:
+def _validate_global_plan(global_plan: List[SavePlan], metadata: Metadata) -> bool:
all_good = True
for key, value in metadata.state_dict_metadata.items():
if isinstance(value, BytesStorageMetadata):
@@ -402,7 +380,10 @@
"""
key:%s has out of bounds chunk:
tensor-size:%s chunk: %s
- """, key, value.size, chunk0
+ """,
+ key,
+ value.size,
+ chunk0,
)
all_good = False
chunks_volume += reduce(operator.mul, chunk0.sizes, 1)
@@ -422,7 +403,10 @@
"""
key:%s invalid fill tensor-volume:
%s chunks-volume: %s
- """, key, tensor_volume, chunks_volume
+ """,
+ key,
+ tensor_volume,
+ chunks_volume,
)
all_good = False
diff --git a/torch/distributed/checkpoint/examples/fsdp_checkpoint_example.py b/torch/distributed/checkpoint/examples/fsdp_checkpoint_example.py
index 347f979..9e2438c 100644
--- a/torch/distributed/checkpoint/examples/fsdp_checkpoint_example.py
+++ b/torch/distributed/checkpoint/examples/fsdp_checkpoint_example.py
@@ -14,12 +14,10 @@
import torch.distributed as dist
import torch.distributed.checkpoint as dist_cp
import torch.multiprocessing as mp
+from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
-from torch.distributed.checkpoint.optimizer import (
- load_sharded_optimizer_state_dict,
-)
CHECKPOINT_DIR = f"/scratch/{os.environ['LOGNAME']}/checkpoint"
diff --git a/torch/distributed/checkpoint/filesystem.py b/torch/distributed/checkpoint/filesystem.py
index 6f7c710..8ee2a67 100644
--- a/torch/distributed/checkpoint/filesystem.py
+++ b/torch/distributed/checkpoint/filesystem.py
@@ -1,53 +1,38 @@
-from abc import ABC, abstractmethod
-import queue
-import threading
import collections
-
-from dataclasses import dataclass
-import os
import dataclasses
import io
+import os
import pickle
-from typing import Optional, List, Union, Dict, cast
+import queue
+import threading
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from pathlib import Path
+from typing import cast, Dict, List, Optional, Union
import torch
+import torch.distributed as dist
from torch import Tensor
+from torch._utils import _get_device_module
+from torch.distributed._shard._utils import narrow_tensor_by_index
+from torch.distributed.checkpoint.checkpointer import Checkpointer
from torch.futures import Future
-from pathlib import Path
-from .metadata import (
- Metadata,
- MetadataIndex,
-)
-from .storage import (
- StorageReader,
- StorageWriter,
- WriteResult,
-)
-
+from .metadata import Metadata, MetadataIndex
from .planner import (
LoadItemType,
- LoadPlanner,
LoadPlan,
+ LoadPlanner,
+ ReadItem,
SavePlan,
SavePlanner,
- ReadItem,
WriteItem,
WriteItemType,
)
-
+from .storage import StorageReader, StorageWriter, WriteResult
from .utils import _create_file_view
-import torch.distributed as dist
-from torch.distributed.checkpoint.checkpointer import Checkpointer
-from torch.distributed._shard._utils import narrow_tensor_by_index
-from torch._utils import _get_device_module
-
-__all__ = [
- "FileSystemWriter",
- "FileSystemReader",
- "FileSystemCheckpointer"
-]
+__all__ = ["FileSystemWriter", "FileSystemReader", "FileSystemCheckpointer"]
@dataclass
@@ -150,10 +135,7 @@
def _refill(self):
with self.device_module.stream(self.stream):
- while (
- not self._done
- and self.in_flight_data < self.inflight_threshhold
- ):
+ while not self._done and self.in_flight_data < self.inflight_threshhold:
_, obj = self.items[self.idx]
self.idx += 1
tensor = self.resolve_fun(obj).detach()
@@ -211,9 +193,7 @@
return size * torch._utils._element_size(dtype)
-def _split_by_size_and_type(
- bins, items: List[WriteItem]
-) -> List[List[WriteItem]]:
+def _split_by_size_and_type(bins, items: List[WriteItem]) -> List[List[WriteItem]]:
if bins == 1:
return [items]
@@ -276,16 +256,12 @@
planner.resolve_data,
)
- tensor_w = [
- wi for wi in write_items if wi.type != WriteItemType.BYTE_IO
- ]
+ tensor_w = [wi for wi in write_items if wi.type != WriteItemType.BYTE_IO]
for write_item in tensor_w:
loader.add(_item_size(write_item), write_item)
loader.start_loading()
- bytes_w = [
- wi for wi in write_items if wi.type == WriteItemType.BYTE_IO
- ]
+ bytes_w = [wi for wi in write_items if wi.type == WriteItemType.BYTE_IO]
write_results = []
with file_name.open("wb") as stream:
@@ -358,9 +334,7 @@
self.path.mkdir(parents=True, exist_ok=True)
return plan
- def prepare_global_plan(
- self, global_plan: List[SavePlan]
- ) -> List[SavePlan]:
+ def prepare_global_plan(self, global_plan: List[SavePlan]) -> List[SavePlan]:
new_plans = [
dataclasses.replace(plan, storage_data=_StoragePrefix(f"__{i}_"))
for i, plan in enumerate(global_plan)
@@ -383,9 +357,7 @@
file_queue: queue.Queue = queue.Queue()
if self.single_file_per_rank:
- for bucket in _split_by_size_and_type(
- self.thread_count, plan.items
- ):
+ for bucket in _split_by_size_and_type(self.thread_count, plan.items):
file_name = gen_file()
file_queue.put((self.path / file_name, file_name, bucket))
else:
@@ -432,9 +404,7 @@
fut.set_result(res)
return fut
- def finish(
- self, metadata: Metadata, results: List[List[WriteResult]]
- ) -> None:
+ def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
storage_md = dict()
for wr_list in results:
storage_md.update({wr.index: wr.storage_data for wr in wr_list})
@@ -507,11 +477,10 @@
def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
return plan
- def prepare_global_plan(
- self, global_plan: List[LoadPlan]
- ) -> List[LoadPlan]:
+ def prepare_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]:
return global_plan
+
class FileSystemCheckpointer(Checkpointer):
"""An implementation of :py:class:`torch.distributed.checkpoint.checkpointer.Checkpointer`
for the file system. Wraps the creation and usage of ``FileSystemWriter`` and ``FileSystemReader``.
@@ -547,11 +516,7 @@
"""
storage_writer = FileSystemWriter(
- path,
- single_file_per_rank,
- sync_files,
- thread_count,
- per_thread_copy_ahead
+ path, single_file_per_rank, sync_files, thread_count, per_thread_copy_ahead
)
storage_reader = FileSystemReader(path)
@@ -562,5 +527,5 @@
coordinator_rank=coordinator_rank,
no_dist=no_dist,
load_planner=load_planner,
- save_planner=save_planner
+ save_planner=save_planner,
)
diff --git a/torch/distributed/checkpoint/metadata.py b/torch/distributed/checkpoint/metadata.py
index 8477e46..4ce6250 100644
--- a/torch/distributed/checkpoint/metadata.py
+++ b/torch/distributed/checkpoint/metadata.py
@@ -1,12 +1,10 @@
from dataclasses import dataclass, field
-from typing import Dict, List, Union, Optional, Sequence, Any
-from torch.distributed._shard.sharded_tensor.metadata import TensorProperties
-from torch.distributed.checkpoint.stateful import StatefulT
+from typing import Any, Dict, List, Optional, Sequence, Union
import torch
-from torch.distributed._shard.sharded_tensor import (
- ShardedTensor,
-)
+from torch.distributed._shard.sharded_tensor import ShardedTensor
+from torch.distributed._shard.sharded_tensor.metadata import TensorProperties
+from torch.distributed.checkpoint.stateful import StatefulT
__all__ = [
"ChunkStorageMetadata",
diff --git a/torch/distributed/checkpoint/optimizer.py b/torch/distributed/checkpoint/optimizer.py
index 5e66f1e..7dd932b 100644
--- a/torch/distributed/checkpoint/optimizer.py
+++ b/torch/distributed/checkpoint/optimizer.py
@@ -1,49 +1,41 @@
# Copyright (c) Meta Platforms, Inc. and affiliates
import dataclasses
-from typing import Dict, List, Optional, Sequence, Tuple, Union, cast
-from torch.distributed.checkpoint.planner import LoadPlan
+from typing import cast, Dict, List, Optional, Sequence, Tuple, Union
import torch
import torch.distributed as dist
+from torch._utils import _get_device_module
from torch.distributed._shard.sharded_tensor.api import ShardedTensor
from torch.distributed._shard.sharded_tensor.metadata import TensorProperties
from torch.distributed._shard.sharded_tensor.shard import Shard
-from torch.distributed._shard.sharding_spec.chunk_sharding_spec import (
- ChunkShardingSpec,
-)
-
-import torch.distributed.checkpoint as dist_cp
+from torch.distributed._shard.sharding_spec.chunk_sharding_spec import ChunkShardingSpec
+from torch.distributed._tensor import DTensor
+from torch.distributed.checkpoint._nested_dict import unflatten_state_dict
+from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
from torch.distributed.checkpoint.metadata import (
BytesStorageMetadata,
+ ChunkStorageMetadata,
Metadata,
MetadataIndex,
STATE_DICT_TYPE,
TensorStorageMetadata,
- ChunkStorageMetadata,
)
-from torch.distributed.distributed_c10d import _get_default_group
-from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor
+from torch.distributed.checkpoint.planner import LoadPlan, LoadPlanner
from torch.distributed.checkpoint.planner_helpers import (
- create_read_items_for_chunk_list,
_create_read_items,
+ create_read_items_for_chunk_list,
)
-from torch.distributed.remote_device import _remote_device
-
-from torch.distributed._tensor import DTensor
-from torch.distributed.checkpoint.default_planner import (
- DefaultLoadPlanner,
-)
-from torch.distributed.checkpoint.planner import LoadPlanner
-
-from torch.distributed.checkpoint._nested_dict import unflatten_state_dict
+from torch.distributed.checkpoint.state_dict_loader import load_state_dict
+from torch.distributed.checkpoint.storage import StorageReader
from torch.distributed.checkpoint.utils import (
_element_wise_add,
_element_wise_sub,
- _normalize_device_info
+ _normalize_device_info,
)
-
-from torch._utils import _get_device_module
+from torch.distributed.distributed_c10d import _get_default_group
+from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor
+from torch.distributed.remote_device import _remote_device
STATE_DICT_2D_LAYOUT = Dict[str, Tuple[Optional[Sequence[int]], Sequence[int]]]
@@ -59,7 +51,9 @@
return "cpu"
device_module = _get_device_module(device_type)
if device_module.is_available():
- return _normalize_device_info(device_type, global_rank % device_module.device_count())
+ return _normalize_device_info(
+ device_type, global_rank % device_module.device_count()
+ )
return "cpu"
@@ -90,18 +84,17 @@
if type(val.local_shards()[0].tensor) is ShardedTensor:
return True
if type(val.local_shards()[0].tensor) is DTensor:
- raise ValueError(
- "Cannot handle DTensor nested insided ShardedTensor"
- )
+        raise ValueError("Cannot handle DTensor nested inside ShardedTensor")
elif type(val) is DTensor and (
- type(val._local_tensor) is DTensor
- or type(val._local_tensor) is ShardedTensor
+ type(val._local_tensor) is DTensor or type(val._local_tensor) is ShardedTensor
):
raise ValueError("Cannot handle nested DTensor")
return False
-def _alloc_tensor(props: TensorProperties, size: Sequence[int], device_type: str = "cuda") -> torch.Tensor:
+def _alloc_tensor(
+ props: TensorProperties, size: Sequence[int], device_type: str = "cuda"
+) -> torch.Tensor:
return torch.empty(
size=size,
dtype=props.dtype,
@@ -181,9 +174,7 @@
local_chunks = [
ChunkStorageMetadata(
offsets=torch.Size(
- _element_wise_add(
- original_shard.metadata.shard_offsets, offset
- )
+ _element_wise_add(original_shard.metadata.shard_offsets, offset)
),
sizes=torch.Size(original_shard.metadata.shard_sizes),
)
@@ -196,9 +187,7 @@
# TODO: we should change _create_sharded_read_items to have more ergonomic API
for ri in reqs:
assert ri.dest_index.offset is not None
- original_offset = _element_wise_sub(
- ri.dest_index.offset, offset
- )
+ original_offset = _element_wise_sub(ri.dest_index.offset, offset)
original_index = dataclasses.replace(
ri.dest_index, offset=torch.Size(original_offset)
)
@@ -214,7 +203,7 @@
def load_sharded_optimizer_state_dict(
model_state_dict: STATE_DICT_TYPE,
optimizer_key: str,
- storage_reader: dist_cp.StorageReader,
+ storage_reader: StorageReader,
planner: Optional[LoadPlanner] = None,
) -> STATE_DICT_TYPE:
"""
@@ -273,7 +262,9 @@
if dp_pg is None:
placements = []
for i in range(dist.get_world_size()):
- device_info = _normalize_device_info(dp_pg_device_type, i % device_module.device_count())
+ device_info = _normalize_device_info(
+ dp_pg_device_type, i % device_module.device_count()
+ )
placements.append(f"rank:{i}/{device_info}")
sharding_spec = ChunkShardingSpec(dim=0, placements=placements) # type: ignore[arg-type]
else:
@@ -294,7 +285,9 @@
# value: TensorStorageMetadata
if value.size.numel() == 1:
- state_dict[key] = _alloc_tensor(value.properties, value.size, dp_pg_device_type)
+ state_dict[key] = _alloc_tensor(
+ value.properties, value.size, dp_pg_device_type
+ )
elif dp_pg is None:
state_dict[key] = _create_chunk_sharded_tensor(
_alloc_tensor(value.properties, value.size, dp_pg_device_type),
@@ -313,10 +306,7 @@
local_shards = []
current_rank = dist.get_rank(dp_pg)
for shard_md in st_md.shards_metadata:
- if (
- cast(_remote_device, shard_md.placement).rank()
- != current_rank
- ):
+ if cast(_remote_device, shard_md.placement).rank() != current_rank:
continue
local_shards.append(
Shard(
@@ -331,18 +321,13 @@
local_shards, st_md, process_group=dp_pg
)
- if (
- spec_key in layout_specs
- and layout_specs[spec_key][0] is not None
- ):
- fqn_to_offset[key] = cast(
- Sequence[int], layout_specs[spec_key][0]
- )
+ if spec_key in layout_specs and layout_specs[spec_key][0] is not None:
+ fqn_to_offset[key] = cast(Sequence[int], layout_specs[spec_key][0])
state_dict[key] = st
# Whether we unflatten before or after doesn't matter
- dist_cp.load_state_dict(
+ load_state_dict(
state_dict=state_dict,
storage_reader=storage_reader,
# FIXME the type of planner is wrong in load_state_dict
diff --git a/torch/distributed/checkpoint/planner.py b/torch/distributed/checkpoint/planner.py
index b4aed46..ebcdfb0 100644
--- a/torch/distributed/checkpoint/planner.py
+++ b/torch/distributed/checkpoint/planner.py
@@ -1,19 +1,13 @@
import abc
-from dataclasses import dataclass
import io
-from typing import List, Tuple, Any, Union, Optional
+from dataclasses import dataclass
+from enum import auto, Enum
+from typing import Any, List, Optional, Tuple, Union
-from enum import Enum, auto
import torch
-
from torch.distributed._shard.sharded_tensor.metadata import TensorProperties
-from .metadata import (
- ChunkStorageMetadata,
- MetadataIndex,
- Metadata,
- STATE_DICT_TYPE,
-)
+from .metadata import ChunkStorageMetadata, Metadata, MetadataIndex, STATE_DICT_TYPE
__all__ = [
@@ -223,9 +217,7 @@
pass
@abc.abstractmethod
- def resolve_data(
- self, write_item: WriteItem
- ) -> Union[torch.Tensor, io.BytesIO]:
+ def resolve_data(self, write_item: WriteItem) -> Union[torch.Tensor, io.BytesIO]:
"""
Transform and prepare ``write_item`` from ``state_dict`` for storage, ensuring idempotency and thread-safety.
diff --git a/torch/distributed/checkpoint/planner_helpers.py b/torch/distributed/checkpoint/planner_helpers.py
index e9cde82..c01ea0d 100644
--- a/torch/distributed/checkpoint/planner_helpers.py
+++ b/torch/distributed/checkpoint/planner_helpers.py
@@ -1,7 +1,6 @@
from typing import Any, List
import torch
-
from torch.distributed._shard.metadata import ShardMetadata
from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._shard.sharded_tensor.metadata import TensorProperties
@@ -16,7 +15,6 @@
STORAGE_TYPES,
TensorStorageMetadata,
)
-
from .planner import (
LoadItemType,
ReadItem,
@@ -25,7 +23,6 @@
WriteItem,
WriteItemType,
)
-
from .resharding import (
_check_shard_metadata_pair_overlap,
_shards_get_overlap_region_wrt_saved_tensor,
diff --git a/torch/distributed/checkpoint/resharding.py b/torch/distributed/checkpoint/resharding.py
index 8753bbc..1ebb0ba 100644
--- a/torch/distributed/checkpoint/resharding.py
+++ b/torch/distributed/checkpoint/resharding.py
@@ -1,12 +1,13 @@
from typing import List, Tuple
-from torch.distributed.checkpoint.metadata import (
- ChunkStorageMetadata
-)
+from torch.distributed.checkpoint.metadata import ChunkStorageMetadata
__all__: List[str] = []
-def _check_shard_metadata_pair_overlap(shard1: ChunkStorageMetadata, shard2: ChunkStorageMetadata):
+
+def _check_shard_metadata_pair_overlap(
+ shard1: ChunkStorageMetadata, shard2: ChunkStorageMetadata
+):
"""Check if two shards overlap."""
# For each dim of each shard, check if one shard resides on the other
# end of second shard with respect to that dim. As an example for a 2D
@@ -21,6 +22,7 @@
return True
+
def _shards_get_overlap_region_wrt_saved_tensor(
saved_shard: ChunkStorageMetadata, current_shard: ChunkStorageMetadata
) -> List[Tuple[int, int, int, int]]:
@@ -56,9 +58,7 @@
if saved_shard_offset > current_shard_offset:
offset_for_saved_tensor = 0
- offset_for_current_tensor = (
- saved_shard_offset - current_shard_offset
- )
+ offset_for_current_tensor = saved_shard_offset - current_shard_offset
else:
offset_for_saved_tensor = current_shard_offset - saved_shard_offset
offset_for_current_tensor = 0
diff --git a/torch/distributed/checkpoint/state_dict_loader.py b/torch/distributed/checkpoint/state_dict_loader.py
index a9291ae..57941e0 100644
--- a/torch/distributed/checkpoint/state_dict_loader.py
+++ b/torch/distributed/checkpoint/state_dict_loader.py
@@ -1,20 +1,18 @@
-from typing import Any, Dict, Optional
import warnings
+from typing import Any, Dict, Optional
import torch
import torch.distributed as dist
from torch.distributed.checkpoint.stateful import Stateful
-from .storage import (
- StorageReader,
-)
-from .planner import LoadPlanner
from .default_planner import DefaultLoadPlanner
-
-from .utils import _DistWrapper, _all_gather_keys
+from .planner import LoadPlanner
+from .storage import StorageReader
+from .utils import _all_gather_keys, _DistWrapper
__all__ = ["load_state_dict", "load"]
+
def load_state_dict(
state_dict: Dict[str, Any],
storage_reader: StorageReader,
@@ -28,7 +26,10 @@
"'load_state_dict' is deprecated and will be removed in future versions. Please use 'load' instead."
)
# TODO: test returning `load` here instead.
- return _load_state_dict(state_dict, storage_reader, process_group, coordinator_rank, no_dist, planner)
+ return _load_state_dict(
+ state_dict, storage_reader, process_group, coordinator_rank, no_dist, planner
+ )
+
def load(
state_dict: Dict[str, Any],
@@ -124,7 +125,9 @@
elem = state_dict[key]
statetful_sd[key] = elem.state_dict() if isinstance(elem, Stateful) else elem
- _load_state_dict(statetful_sd, storage_reader, process_group, coordinator_rank, no_dist, planner)
+ _load_state_dict(
+ statetful_sd, storage_reader, process_group, coordinator_rank, no_dist, planner
+ )
for key in keys:
if key not in state_dict:
continue
@@ -133,6 +136,7 @@
elem.load_state_dict(statetful_sd[key])
state_dict[key] = elem
+
def _load_state_dict(
state_dict: Dict[str, Any],
storage_reader: StorageReader,
@@ -141,7 +145,6 @@
no_dist: bool = False,
planner: Optional[LoadPlanner] = None,
) -> None:
-
torch._C._log_api_usage_once("torch.distributed.checkpoint.load_state_dict")
distW = _DistWrapper(process_group, not no_dist, coordinator_rank)
diff --git a/torch/distributed/checkpoint/state_dict_saver.py b/torch/distributed/checkpoint/state_dict_saver.py
index 2fba073..f672aa3 100644
--- a/torch/distributed/checkpoint/state_dict_saver.py
+++ b/torch/distributed/checkpoint/state_dict_saver.py
@@ -1,18 +1,14 @@
-from typing import Optional
import warnings
+from typing import Optional
import torch
import torch.distributed as dist
from torch.distributed.checkpoint.stateful import Stateful
-from .planner import SavePlanner
+
from .default_planner import DefaultSavePlanner
-
-
-from .storage import (
- StorageWriter,
-)
-
from .metadata import Metadata, STATE_DICT_TYPE
+from .planner import SavePlanner
+from .storage import StorageWriter
from .utils import _DistWrapper
__all__ = ["save_state_dict", "save"]
@@ -32,7 +28,10 @@
)
# TODO: test returning `save` here instead.
- return _save_state_dict(state_dict, storage_writer, process_group, coordinator_rank, no_dist, planner)
+ return _save_state_dict(
+ state_dict, storage_writer, process_group, coordinator_rank, no_dist, planner
+ )
+
def save(
state_dict: STATE_DICT_TYPE,
@@ -108,7 +107,9 @@
dumpable_state_dict = {}
for key, elem in state_dict.items():
- dumpable_state_dict[key] = elem.state_dict() if isinstance(elem, Stateful) else elem
+ dumpable_state_dict[key] = (
+ elem.state_dict() if isinstance(elem, Stateful) else elem
+ )
return _save_state_dict(
dumpable_state_dict,
@@ -116,9 +117,10 @@
process_group,
coordinator_rank,
no_dist,
- planner
+ planner,
)
+
def _save_state_dict(
state_dict: STATE_DICT_TYPE,
storage_writer: StorageWriter,
@@ -127,7 +129,6 @@
no_dist: bool = False,
planner: Optional[SavePlanner] = None,
) -> Metadata:
-
torch._C._log_api_usage_once("torch.distributed.checkpoint.save_state_dict")
distW = _DistWrapper(process_group, not no_dist, coordinator_rank)
@@ -149,9 +150,7 @@
nonlocal global_metatadata
assert planner is not None
- all_local_plans, global_metatadata = planner.create_global_plan(
- all_local_plans
- )
+ all_local_plans, global_metatadata = planner.create_global_plan(all_local_plans)
all_local_plans = storage_writer.prepare_global_plan(all_local_plans)
return all_local_plans
diff --git a/torch/distributed/checkpoint/storage.py b/torch/distributed/checkpoint/storage.py
index a15db14..59282df 100644
--- a/torch/distributed/checkpoint/storage.py
+++ b/torch/distributed/checkpoint/storage.py
@@ -1,20 +1,11 @@
import abc
from dataclasses import dataclass
-from typing import List, Any
+from typing import Any, List
from torch.futures import Future
-from .metadata import (
- Metadata,
- MetadataIndex,
-)
-
-from .planner import (
- LoadPlan,
- SavePlan,
- SavePlanner,
- LoadPlanner,
-)
+from .metadata import Metadata, MetadataIndex
+from .planner import LoadPlan, LoadPlanner, SavePlan, SavePlanner
__all__ = ["WriteResult", "StorageWriter", "StorageReader"]
@@ -115,9 +106,7 @@
pass
@abc.abstractmethod
- def finish(
- self, metadata: Metadata, results: List[List[WriteResult]]
- ) -> None:
+ def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
"""
Write the metadata and marks the current checkpoint as successful.
diff --git a/torch/distributed/checkpoint/utils.py b/torch/distributed/checkpoint/utils.py
index 25f96f1..b8dadf3 100644
--- a/torch/distributed/checkpoint/utils.py
+++ b/torch/distributed/checkpoint/utils.py
@@ -1,37 +1,21 @@
-import os
import io
import itertools
-from typing import (
- List,
- Callable,
- Optional,
- Union,
- TypeVar,
- Dict,
- Any,
- cast,
- Sequence
-)
-import torch.distributed as dist
-from .api import (
- CheckpointException,
- _wrap_exception,
- _is_wrapped_exception,
- WRAPPED_EXCEPTION,
-)
+import os
+from typing import Any, Callable, cast, Dict, List, Optional, Sequence, TypeVar, Union
import torch
-
-from torch.distributed._shard.sharded_tensor import (
- ShardedTensor,
-)
+import torch.distributed as dist
+from torch.distributed._shard.sharded_tensor import ShardedTensor
from torch.distributed._shard.sharded_tensor.shard import Shard
from torch.distributed._tensor import DTensor
-from .metadata import (
- STATE_DICT_TYPE,
- MetadataIndex,
+from .api import (
+ _is_wrapped_exception,
+ _wrap_exception,
+ CheckpointException,
+ WRAPPED_EXCEPTION,
)
+from .metadata import MetadataIndex, STATE_DICT_TYPE
__all__ = ["find_tensor_shard", "find_state_dict_object"]
@@ -47,6 +31,7 @@
{i: err for i, err in enumerate(results) if _is_wrapped_exception(err)},
)
+
def _all_gather_keys(local_dict: Dict[Any, Any]) -> List[Any]:
"""Gathers all keys, and returns them sorted."""
keys = list(local_dict.keys())
@@ -55,6 +40,7 @@
dist.all_gather_object(gathered_keys, keys)
return sorted(set(itertools.chain.from_iterable(gathered_keys)))
+
class _DistWrapper:
"""
This is a wrapper around PG that provides a series of features around object collectives.
@@ -123,9 +109,7 @@
def all_gather_object(self, object: T) -> List[T]:
"""Implement functionality similar to c10d::all_gather_object but without distributed enabled."""
if self.use_dist:
- gather_objs = cast(
- List[T], [None] * dist.get_world_size(self.group)
- )
+ gather_objs = cast(List[T], [None] * dist.get_world_size(self.group))
dist.all_gather_object(
object_list=gather_objs, obj=object, group=self.group
@@ -140,9 +124,7 @@
gather_result = cast(List[T], [None])
dist.scatter_object_list(
scatter_object_output_list=gather_result,
- scatter_object_input_list=object_list
- if self.is_coordinator
- else None,
+ scatter_object_input_list=object_list if self.is_coordinator else None,
src=self.coordinator_rank,
group=self.group,
)
@@ -282,9 +264,7 @@
try:
result = map_fun()
except BaseException as e:
- result = CheckpointException(
- step, {self.rank: _wrap_exception(e)}
- )
+ result = CheckpointException(step, {self.rank: _wrap_exception(e)})
final_result = self.broadcast_object(result)
if isinstance(final_result, CheckpointException):
raise final_result
@@ -302,22 +282,17 @@
if index.index is not None:
if (
len(shards) > index.index
- and torch.Size(shards[index.index].metadata.shard_offsets)
- == index.offset
+ and torch.Size(shards[index.index].metadata.shard_offsets) == index.offset
):
return shards[index.index]
for shard in shards:
if torch.Size(shard.metadata.shard_offsets) == index.offset:
return shard
- raise ValueError(
- f"Could not find shard at '{index.offset}' for FQN: '{index.fqn}'"
- )
+ raise ValueError(f"Could not find shard at '{index.offset}' for FQN: '{index.fqn}'")
-def find_tensor_shard(
- tensor: torch.Tensor, index: MetadataIndex
-) -> torch.Tensor:
+def find_tensor_shard(tensor: torch.Tensor, index: MetadataIndex) -> torch.Tensor:
if isinstance(tensor, DTensor):
return tensor.to_local()
if isinstance(tensor, ShardedTensor):
@@ -332,9 +307,7 @@
return tensor
-def find_state_dict_object(
- state_dict: STATE_DICT_TYPE, index: MetadataIndex
-) -> Any:
+def find_state_dict_object(state_dict: STATE_DICT_TYPE, index: MetadataIndex) -> Any:
if index.fqn not in state_dict:
raise ValueError(f"Could not find FQN: '{index.fqn}'")
obj = state_dict[index.fqn]