[DCP][BE] Apply ufmt to DCP and turn on lintrunner for DCP (#115302)

No logic changes; this PR only adds typing and applies ufmt formatting.

Differential Revision: [D51914982](https://our.internmc.facebook.com/intern/diff/D51914982/)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/115302
Approved by: https://github.com/XilunWu, https://github.com/wz337, https://github.com/LucasLLC
ghstack dependencies: #115523
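
To reproduce the reformatting locally, something like the following should work (a sketch only; it assumes a lintrunner/ufmt setup as described in CONTRIBUTING.md, and the exact flags and paths are my assumption rather than taken from this PR):

```
# Assumed invocation: run only the UFMT linter over the DCP sources and apply its fixes.
lintrunner -a --take UFMT torch/distributed/checkpoint test/distributed/checkpoint

# Or, equivalently, run ufmt (usort + black) directly on the same paths.
ufmt format torch/distributed/checkpoint test/distributed/checkpoint
```

The .lintrunner.toml change below simply removes the DCP files from the UFMT exclude lists so that lintrunner enforces this formatting going forward.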
diff --git a/.lintrunner.toml b/.lintrunner.toml
index ae54cfc..4a9ff17 100644
--- a/.lintrunner.toml
+++ b/.lintrunner.toml
@@ -869,7 +869,6 @@
     'test/bottleneck_test/**',  # excluded by test/run_test.py
     'test/distributed/argparse_util_test.py',
     'test/distributed/bin/test_script.py',
-    'test/distributed/checkpoint/e2e/test_pipeline.py',
     'test/distributed/elastic/agent/server/test/local_elastic_agent_test.py',
     'test/distributed/elastic/multiprocessing/bin/test_script.py',
     'test/distributed/elastic/multiprocessing/bin/zombie_test.py',
@@ -1094,19 +1093,6 @@
     'test/distributed/algorithms/test_join.py',
     'test/distributed/argparse_util_test.py',
     'test/distributed/bin/test_script.py',
-    'test/distributed/checkpoint/test_2d_fsdp_dt_checkpoint.py',
-    'test/distributed/checkpoint/test_checkpoint.py',
-    'test/distributed/checkpoint/test_dedup_tensors.py',
-    'test/distributed/checkpoint/test_dtensor_checkpoint.py',
-    'test/distributed/checkpoint/test_file_system_checkpoint.py',
-    'test/distributed/checkpoint/test_file_system_checkpoint_cpu.py',
-    'test/distributed/checkpoint/test_fsdp_model_state.py',
-    'test/distributed/checkpoint/test_fsdp_optim_state.py',
-    'test/distributed/checkpoint/test_fsspec.py',
-    'test/distributed/checkpoint/test_nested_dict.py',
-    'test/distributed/checkpoint/test_planner.py',
-    'test/distributed/checkpoint/test_traverse.py',
-    'test/distributed/checkpoint/test_utils.py',
     'test/distributed/elastic/agent/server/test/__init__.py',
     'test/distributed/elastic/agent/server/test/api_test.py',
     'test/distributed/elastic/agent/server/test/local_elastic_agent_test.py',
@@ -2010,25 +1996,6 @@
     'torch/distributed/autograd/__init__.py',
     'torch/distributed/benchmarks/benchmark_ddp_rpc.py',
     'torch/distributed/c10d_logger.py',
-    'torch/distributed/checkpoint/__init__.py',
-    'torch/distributed/checkpoint/_dedup_tensors.py',
-    'torch/distributed/checkpoint/_fsspec_filesystem.py',
-    'torch/distributed/checkpoint/_nested_dict.py',
-    'torch/distributed/checkpoint/_sharded_tensor_utils.py',
-    'torch/distributed/checkpoint/_traverse.py',
-    'torch/distributed/checkpoint/api.py',
-    'torch/distributed/checkpoint/default_planner.py',
-    'torch/distributed/checkpoint/examples/fsdp_checkpoint_example.py',
-    'torch/distributed/checkpoint/filesystem.py',
-    'torch/distributed/checkpoint/metadata.py',
-    'torch/distributed/checkpoint/optimizer.py',
-    'torch/distributed/checkpoint/planner.py',
-    'torch/distributed/checkpoint/planner_helpers.py',
-    'torch/distributed/checkpoint/resharding.py',
-    'torch/distributed/checkpoint/state_dict_loader.py',
-    'torch/distributed/checkpoint/state_dict_saver.py',
-    'torch/distributed/checkpoint/storage.py',
-    'torch/distributed/checkpoint/utils.py',
     'torch/distributed/collective_utils.py',
     'torch/distributed/constants.py',
     'torch/distributed/distributed_c10d.py',
@@ -2442,7 +2409,6 @@
     'torch/testing/_internal/distributed/_shard/test_common.py',
     'torch/testing/_internal/distributed/_tensor/__init__.py',
     'torch/testing/_internal/distributed/_tensor/common_dtensor.py',
-    'torch/testing/_internal/distributed/checkpoint_utils.py',
     'torch/testing/_internal/distributed/ddp_under_dist_autograd_test.py',
     'torch/testing/_internal/distributed/distributed_test.py',
     'torch/testing/_internal/distributed/distributed_utils.py',
diff --git a/test/distributed/checkpoint/e2e/test_pipeline.py b/test/distributed/checkpoint/e2e/test_pipeline.py
index cbf60b9..e197058 100644
--- a/test/distributed/checkpoint/e2e/test_pipeline.py
+++ b/test/distributed/checkpoint/e2e/test_pipeline.py
@@ -10,7 +10,7 @@
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_fsdp import FSDPTest
-from torch.testing._internal.common_utils import TEST_WITH_DEV_DBG_ASAN
+from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
 from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
 
 
@@ -102,3 +102,7 @@
         self.assertTrue(os.path.exists(pipeline_dir))
         self.save_with_pipeline(pipeline_dir)
         self.load_with_fsdp(pipeline_dir)
+
+
+if __name__ == "__main__":
+    run_tests()
diff --git a/test/distributed/checkpoint/test_checkpoint.py b/test/distributed/checkpoint/test_checkpoint.py
index a2c9a29a..937853e 100644
--- a/test/distributed/checkpoint/test_checkpoint.py
+++ b/test/distributed/checkpoint/test_checkpoint.py
@@ -1,29 +1,28 @@
 # Owner(s): ["oncall: distributed"]
 
 import sys
-from typing import Optional, List, cast
-from torch.distributed.checkpoint.storage import WriteResult
-
-from torch.distributed.checkpoint import (
-    StorageReader,
-    StorageWriter,
-    CheckpointException,
-    load_state_dict,
-    save_state_dict,
-)
+from typing import cast, List, Optional
 
 import torch
 import torch.distributed as dist
-import torch.nn
 import torch.futures
-from torch.futures import Future
+import torch.nn
 
 from torch.distributed._shard import sharded_tensor
 
-from torch.distributed.checkpoint.default_planner import (
-    _create_default_local_metadata,
+from torch.distributed._shard.sharded_tensor import ShardedTensor, state_dict_hook
+from torch.distributed._shard.sharding_spec import ChunkShardingSpec
+
+from torch.distributed.checkpoint import (
+    CheckpointException,
+    load_state_dict,
+    save_state_dict,
+    StorageReader,
+    StorageWriter,
 )
 
+from torch.distributed.checkpoint.default_planner import _create_default_local_metadata
+
 from torch.distributed.checkpoint.metadata import (
     BytesStorageMetadata,
     Metadata,
@@ -31,31 +30,21 @@
 )
 
 from torch.distributed.checkpoint.planner import (
-    SavePlan,
-    SavePlanner,
     LoadPlan,
     LoadPlanner,
+    SavePlan,
+    SavePlanner,
 )
+from torch.distributed.checkpoint.storage import WriteResult
+from torch.futures import Future
+from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
 
-from torch.distributed._shard.sharded_tensor import (
-    state_dict_hook,
-    ShardedTensor,
-)
-from torch.distributed._shard.sharding_spec import ChunkShardingSpec
-from torch.testing._internal.common_distributed import (
-    requires_nccl,
-    skip_if_lt_x_gpu,
-)
+from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
 from torch.testing._internal.distributed._shard.sharded_tensor import (
     ShardedTensorTestBase,
     with_comms,
 )
 
-from torch.testing._internal.common_utils import (
-    TEST_WITH_DEV_DBG_ASAN,
-    run_tests,
-)
-
 if TEST_WITH_DEV_DBG_ASAN:
     print(
         "Skip dev-asan as torch + multiprocessing spawn have known issues",
@@ -175,9 +164,7 @@
         ranks = self._get_ranks(name)
         fut = Future()
         if ranks is not None and self.rank in ranks:
-            fut.set_exception(
-                ValueError(f"async rank fail {self.rank} for {name}")
-            )
+            fut.set_exception(ValueError(f"async rank fail {self.rank} for {name}"))
         else:
             fut.set_result(result)
         return fut
@@ -204,9 +191,7 @@
         self._fail_rank("fail_write_data")
         return self._fail_rank_async("fail_write_data_async", [])
 
-    def finish(
-        self, metadata: Metadata, results: List[List[WriteResult]]
-    ) -> None:
+    def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
         self._fail_rank("fail_finish")
 
 
@@ -239,9 +224,7 @@
     def get_spec(self):
         return ChunkShardingSpec(
             dim=0,
-            placements=[
-                f"rank:{r}/cuda:{r}" for r in range(dist.get_world_size())
-            ],
+            placements=[f"rank:{r}/cuda:{r}" for r in range(dist.get_world_size())],
         )
 
     @with_comms(init_rpc=False)
diff --git a/test/distributed/checkpoint/test_dedup_tensors.py b/test/distributed/checkpoint/test_dedup_tensors.py
index 6f2b81c..37525f6 100644
--- a/test/distributed/checkpoint/test_dedup_tensors.py
+++ b/test/distributed/checkpoint/test_dedup_tensors.py
@@ -1,12 +1,11 @@
 # Owner(s): ["oncall: distributed"]
 
 import dataclasses
+
 import torch
 from torch.distributed.checkpoint._dedup_tensors import dedup_tensors
 from torch.distributed.checkpoint.planner import SavePlan, WriteItemType
-from torch.distributed.checkpoint.planner_helpers import (
-    _create_write_item_for_tensor,
-)
+from torch.distributed.checkpoint.planner_helpers import _create_write_item_for_tensor
 from torch.testing._internal.common_utils import run_tests, TestCase
 
 
@@ -33,9 +32,7 @@
         self.assertEqual(2, len(dedup_plans[0].items))
         self.assertEqual(1, len(dedup_plans[1].items))
 
-        self.assertIn(
-            "tensor_0", (item.index.fqn for item in dedup_plans[0].items)
-        )
+        self.assertIn("tensor_0", (item.index.fqn for item in dedup_plans[0].items))
         self.assertIn("r0", (item.index.fqn for item in dedup_plans[0].items))
 
         self.assertIn("r1", (item.index.fqn for item in dedup_plans[1].items))
diff --git a/test/distributed/checkpoint/test_dtensor_checkpoint.py b/test/distributed/checkpoint/test_dtensor_checkpoint.py
index 5f664bd..8c4a1ff 100644
--- a/test/distributed/checkpoint/test_dtensor_checkpoint.py
+++ b/test/distributed/checkpoint/test_dtensor_checkpoint.py
@@ -6,19 +6,19 @@
 import torch.distributed.checkpoint as dist_cp
 from torch.distributed._tensor import (
     DeviceMesh,
+    distribute_tensor,
     DTensor,
     Replicate,
     Shard,
-    distribute_tensor,
     zeros,
 )
-from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
+from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     DTensorTestBase,
     skip_if_lt_x_gpu,
     with_comms,
 )
-from torch.testing._internal.common_utils import run_tests
+from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
 
 
 SUBMESH_TENSOR_SIZE = 6
@@ -81,9 +81,7 @@
             device_type=self.device_type,
             mesh=range(dist.get_world_size()),
         )
-        sharded_dt = distribute_tensor(
-            tensor_to_shard, mesh, placements=[Shard(0)]
-        )
+        sharded_dt = distribute_tensor(tensor_to_shard, mesh, placements=[Shard(0)])
         replicated_dt = distribute_tensor(
             tensor_to_replicate, mesh, placements=[Replicate()]
         )
@@ -179,9 +177,7 @@
             storage_writer=dist_cp.FileSystemWriter(path=CHECKPOINT_DIR),
             planner=dist_cp.DefaultSavePlanner(),
         )
-        model, _, _ = self.create_dtensor_model(
-            local_tensor * 10, local_tensor_2 * 10
-        )
+        model, _, _ = self.create_dtensor_model(local_tensor * 10, local_tensor_2 * 10)
         state_dict = model.state_dict()
 
         """
@@ -247,9 +243,7 @@
             if k == "submesh_sdt":
                 if self.rank % 2 == 0:
                     shard_size = int(SUBMESH_TENSOR_SIZE / v.device_mesh.size())
-                    self.assertEqual(
-                        v.to_local().size(), torch.Size([shard_size])
-                    )
+                    self.assertEqual(v.to_local().size(), torch.Size([shard_size]))
                     self.assertEqual(v.to_local(), torch.zeros([shard_size]))
                 else:
                     self.assertEqual(v.to_local().size(), torch.Size([0]))
@@ -258,9 +252,7 @@
             if k == "submesh_rdt":
                 if self.rank % 2 == 0:
                     shard_size = SUBMESH_TENSOR_SIZE
-                    self.assertEqual(
-                        v.to_local().size(), torch.Size([shard_size])
-                    )
+                    self.assertEqual(v.to_local().size(), torch.Size([shard_size]))
                     self.assertEqual(v.to_local(), torch.zeros([shard_size]))
                 else:
                     self.assertEqual(v.to_local().size(), torch.Size([0]))
diff --git a/test/distributed/checkpoint/test_file_system_checkpoint.py b/test/distributed/checkpoint/test_file_system_checkpoint.py
index 3d92e79..d33b60a 100644
--- a/test/distributed/checkpoint/test_file_system_checkpoint.py
+++ b/test/distributed/checkpoint/test_file_system_checkpoint.py
@@ -1,42 +1,21 @@
 # Owner(s): ["oncall: distributed"]
 
 import os
-import sys
 import shutil
+import sys
 import tempfile
 from typing import Dict
 
 import torch
 import torch.distributed as dist
 from torch.distributed._shard import sharded_tensor
-from torch.distributed._shard.sharded_tensor import (
-    ShardedTensor,
-    state_dict_hook,
-)
+from torch.distributed._shard.sharded_tensor import ShardedTensor, state_dict_hook
 from torch.distributed._shard.sharding_spec import (
     ChunkShardingSpec,
     EnumerableShardingSpec,
     ShardingSpec,
     ShardMetadata,
 )
-from torch.testing._internal.common_distributed import (
-    requires_nccl,
-    skip_if_lt_x_gpu,
-)
-from torch.testing._internal.common_utils import TestCase
-from torch.testing._internal.distributed._shard.sharded_tensor import (
-    ShardedTensorTestBase,
-    with_comms,
-)
-from torch.testing._internal.distributed._shard.sharded_tensor._test_st_common import (
-    MyShardedModel1,
-)
-
-
-from torch.testing._internal.common_utils import (
-    TEST_WITH_DEV_DBG_ASAN,
-    run_tests,
-)
 
 from torch.distributed.checkpoint import (
     FileSystemReader,
@@ -44,6 +23,20 @@
     load_state_dict,
     save_state_dict,
 )
+from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
+
+from torch.testing._internal.common_utils import (
+    run_tests,
+    TEST_WITH_DEV_DBG_ASAN,
+    TestCase,
+)
+from torch.testing._internal.distributed._shard.sharded_tensor import (
+    ShardedTensorTestBase,
+    with_comms,
+)
+from torch.testing._internal.distributed._shard.sharded_tensor._test_st_common import (
+    MyShardedModel1,
+)
 
 
 if TEST_WITH_DEV_DBG_ASAN:
@@ -122,9 +115,7 @@
             state_dict_to_load_to = MyTestModule().state_dict()
 
             with self.assertRaises(AssertionError):
-                assert_state_dict_equal(
-                    self, state_dict_to_load_to, state_dict_to_save
-                )
+                assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
 
             # Load from file without any resharding
             fs_reader = FileSystemReader(path=path)
@@ -134,9 +125,7 @@
                 no_dist=True,
             )
 
-            assert_state_dict_equal(
-                self, state_dict_to_load_to, state_dict_to_save
-            )
+            assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
 
         with tempfile.TemporaryDirectory() as path:
             state_dict_to_save = MyTestModule().state_dict()
@@ -151,9 +140,7 @@
             state_dict_to_load_to = MyTestModule().state_dict()
 
             with self.assertRaises(AssertionError):
-                assert_state_dict_equal(
-                    self, state_dict_to_load_to, state_dict_to_save
-                )
+                assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
 
             # Load from file without any resharding
             fs_reader = FileSystemReader(path=path)
@@ -163,9 +150,7 @@
                 no_dist=True,
             )
 
-            assert_state_dict_equal(
-                self, state_dict_to_load_to, state_dict_to_save
-            )
+            assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
 
 
 class TestDistributedStateDictSaveLoadWithSharedTensor(ShardedTensorTestBase):
@@ -212,15 +197,11 @@
         dist.barrier()
 
         with self.assertRaises(AssertionError):
-            assert_state_dict_equal(
-                self, state_dict_to_load_to, state_dict_to_save
-            )
+            assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
 
         # Test load.
         fs_reader = FileSystemReader(path=path)
-        load_state_dict(
-            state_dict=state_dict_to_load_to, storage_reader=fs_reader
-        )
+        load_state_dict(state_dict=state_dict_to_load_to, storage_reader=fs_reader)
 
         assert_state_dict_equal(self, state_dict_to_load_to, state_dict_to_save)
         dist.barrier()
@@ -238,9 +219,7 @@
 
     def load_tensor(self, tensor: ShardedTensor) -> torch.Tensor:
         res = (
-            torch.zeros(tensor.shape, device="cuda:0")
-            if dist.get_rank() == 0
-            else None
+            torch.zeros(tensor.shape, device="cuda:0") if dist.get_rank() == 0 else None
         )
         tensor.gather(out=res)
         return res
@@ -335,9 +314,7 @@
                 state_dict_to_save = model_to_save.state_dict()
 
                 fs_writer = FileSystemWriter(path=path)
-                save_state_dict(
-                    state_dict=state_dict_to_save, storage_writer=fs_writer
-                )
+                save_state_dict(state_dict=state_dict_to_save, storage_writer=fs_writer)
 
                 dist.barrier()
 
@@ -404,9 +381,7 @@
 
         fs_reader = FileSystemReader(path=path)
 
-        load_state_dict(
-            state_dict=state_dict_to_load_to, storage_reader=fs_reader
-        )
+        load_state_dict(state_dict=state_dict_to_load_to, storage_reader=fs_reader)
 
         # We can't use torch.allclose since each ST has a different sharding spec
         store_tensor = self.load_tensor(model_to_save.sharded_tensor)
@@ -516,9 +491,7 @@
                         f"save-spec {save_spec} load-spec {load_spec}",
                     )
                     self.assertTrue(
-                        torch.allclose(
-                            save_dict["replicated"], load_dict_replicated
-                        ),
+                        torch.allclose(save_dict["replicated"], load_dict_replicated),
                         f"save-spec {save_spec} load-spec {load_spec}",
                     )
 
diff --git a/test/distributed/checkpoint/test_fsdp_model_state.py b/test/distributed/checkpoint/test_fsdp_model_state.py
index 6cacf7c..9ccf36c 100644
--- a/test/distributed/checkpoint/test_fsdp_model_state.py
+++ b/test/distributed/checkpoint/test_fsdp_model_state.py
@@ -1,23 +1,23 @@
 # Owner(s): ["oncall: distributed"]
 
 import torch
+import torch.distributed as dist
+import torch.distributed.checkpoint as dist_cp
+
+from torch.distributed.checkpoint.default_planner import (
+    DefaultLoadPlanner,
+    DefaultSavePlanner,
+)
 
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
-import torch.distributed.checkpoint as dist_cp
-import torch.distributed as dist
-
-from torch.distributed.checkpoint.default_planner import (
-    DefaultSavePlanner,
-    DefaultLoadPlanner,
-)
+from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
+from torch.testing._internal.common_utils import run_tests
 
 from torch.testing._internal.distributed._tensor.common_dtensor import (
     DTensorTestBase,
     with_comms,
 )
-from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
-from torch.testing._internal.common_utils import run_tests
 from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir
 
 
diff --git a/test/distributed/checkpoint/test_fsspec.py b/test/distributed/checkpoint/test_fsspec.py
index 14a4de9..b5d4195 100644
--- a/test/distributed/checkpoint/test_fsspec.py
+++ b/test/distributed/checkpoint/test_fsspec.py
@@ -9,23 +9,12 @@
 import torch.distributed as dist
 import torch.distributed.checkpoint as dcp
 import torch.nn as nn
-from torch.distributed.checkpoint._fsspec_filesystem import (
-    FsspecReader,
-    FsspecWriter,
-)
-from torch.distributed.checkpoint.optimizer import (
-    load_sharded_optimizer_state_dict,
-)
+from torch.distributed.checkpoint._fsspec_filesystem import FsspecReader, FsspecWriter
+from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
-from torch.testing._internal.common_distributed import (
-    requires_nccl,
-    skip_if_lt_x_gpu,
-)
-from torch.testing._internal.common_utils import (
-    run_tests,
-    TestCase,
-)
+from torch.testing._internal.common_distributed import requires_nccl, skip_if_lt_x_gpu
+from torch.testing._internal.common_utils import run_tests, TestCase
 from torch.testing._internal.distributed._shard.sharded_tensor import (
     ShardedTensorTestBase,
     with_comms,
@@ -182,9 +171,7 @@
             return list(iter(opt.state.values()))[idx]
 
         # Adam lazily creates its state
-        self.assertEqual(
-            opt_at(optim, 0)["exp_avg"], opt_at(optim_2, 0)["exp_avg"]
-        )
+        self.assertEqual(opt_at(optim, 0)["exp_avg"], opt_at(optim_2, 0)["exp_avg"])
         self.assertEqual(
             opt_at(optim, 0)["exp_avg_sq"], opt_at(optim_2, 0)["exp_avg_sq"]
         )
diff --git a/test/distributed/checkpoint/test_nested_dict.py b/test/distributed/checkpoint/test_nested_dict.py
index 115982e..12df302 100644
--- a/test/distributed/checkpoint/test_nested_dict.py
+++ b/test/distributed/checkpoint/test_nested_dict.py
@@ -1,11 +1,11 @@
 # Owner(s): ["oncall: distributed"]
 
 import torch
-from torch.testing._internal.common_utils import run_tests, TestCase
 from torch.distributed.checkpoint._nested_dict import (
     flatten_state_dict,
     unflatten_state_dict,
 )
+from torch.testing._internal.common_utils import run_tests, TestCase
 
 
 class TestFlattening(TestCase):
diff --git a/test/distributed/checkpoint/test_planner.py b/test/distributed/checkpoint/test_planner.py
index a8563c7..53c129d 100644
--- a/test/distributed/checkpoint/test_planner.py
+++ b/test/distributed/checkpoint/test_planner.py
@@ -3,43 +3,45 @@
 import sys
 
 import torch
-from torch.distributed.checkpoint.planner import LoadItemType, WriteItemType
 
 from torch.distributed._shard.sharded_tensor import (
     Shard,
-    ShardMetadata,
     ShardedTensor,
     ShardedTensorMetadata,
+    ShardMetadata,
 )
 from torch.distributed._shard.sharded_tensor.metadata import TensorProperties
+from torch.distributed.checkpoint._dedup_tensors import dedup_tensors
 
-from torch.testing._internal.common_utils import (
-    TestCase,
-    TEST_WITH_DEV_DBG_ASAN,
-    run_tests,
+from torch.distributed.checkpoint.default_planner import (
+    _create_default_local_metadata,
+    create_default_global_save_plan,
+    create_default_local_load_plan,
+    create_default_local_save_plan,
 )
 from torch.distributed.checkpoint.metadata import (
     BytesStorageMetadata,
+    ChunkStorageMetadata,
     MetadataIndex,
     TensorStorageMetadata,
-    ChunkStorageMetadata,
+)
+from torch.distributed.checkpoint.planner import LoadItemType, WriteItemType
+
+from torch.distributed.checkpoint.planner_helpers import (
+    create_read_items_for_chunk_list,
+)
+
+from torch.testing._internal.common_utils import (
+    run_tests,
+    TEST_WITH_DEV_DBG_ASAN,
+    TestCase,
 )
 
 from torch.testing._internal.distributed.distributed_utils import (
+    with_dist,
     with_fake_comms,
-    with_dist
 )
 
-from torch.distributed.checkpoint.default_planner import (
-    create_default_global_save_plan,
-    create_default_local_save_plan,
-    create_default_local_load_plan,
-    _create_default_local_metadata
-)
-
-from torch.distributed.checkpoint.planner_helpers import create_read_items_for_chunk_list
-from torch.distributed.checkpoint._dedup_tensors import dedup_tensors
-
 
 if TEST_WITH_DEV_DBG_ASAN:
     print(
@@ -48,30 +50,34 @@
     )
     sys.exit(0)
 
+
 def create_sharded_tensor(rank, world_size, shards_per_rank, shard_size=8):
     shards_metadata = []
     local_shards = []
     for idx in range(0, world_size * shards_per_rank):
         shard_rank = idx // shards_per_rank
-        shard_md = ShardMetadata(shard_offsets=[idx * shard_size], shard_sizes=[shard_size], placement=f"rank:{shard_rank}/cpu")
+        shard_md = ShardMetadata(
+            shard_offsets=[idx * shard_size],
+            shard_sizes=[shard_size],
+            placement=f"rank:{shard_rank}/cpu",
+        )
         shards_metadata.append(shard_md)
         if shard_rank == rank:
             shard = Shard.from_tensor_and_offsets(
                 torch.rand(*shard_md.shard_sizes),
                 shard_offsets=shard_md.shard_offsets,
-                rank=rank
+                rank=rank,
             )
             local_shards.append(shard)
 
     sharded_tensor_md = ShardedTensorMetadata(
         shards_metadata=shards_metadata,
         size=torch.Size([shard_size * len(shards_metadata)]),
-        tensor_properties=TensorProperties.create_from_tensor(torch.zeros(1))
+        tensor_properties=TensorProperties.create_from_tensor(torch.zeros(1)),
     )
 
     return ShardedTensor._init_from_local_shards_and_global_metadata(
-        local_shards=local_shards,
-        sharded_tensor_metadata=sharded_tensor_md
+        local_shards=local_shards, sharded_tensor_metadata=sharded_tensor_md
     )
 
 
@@ -81,18 +87,17 @@
         tensor = torch.rand(10)
         val = [1, 2, 3]
         st = create_sharded_tensor(rank=1, world_size=4, shards_per_rank=1)
-        state_dict = {
-            "tensor": tensor,
-            "value": val,
-            "st": st
-        }
+        state_dict = {"tensor": tensor, "value": val, "st": st}
         plan = create_default_local_save_plan(state_dict, False)
         self.assertEqual(2, len(plan.items))
         wi = plan.items[0]
         self.assertEqual(wi.index, MetadataIndex("tensor", [0]))
         self.assertEqual(wi.type, WriteItemType.TENSOR)
         self.assertEqual(wi.tensor_data.size, tensor.size())
-        self.assertEqual(wi.tensor_data.properties, TensorProperties.create_from_tensor(torch.zeros(1)))
+        self.assertEqual(
+            wi.tensor_data.properties,
+            TensorProperties.create_from_tensor(torch.zeros(1)),
+        )
         self.assertEqual(wi.tensor_data.chunk.offsets, torch.Size([0]))
         self.assertEqual(wi.tensor_data.chunk.sizes, torch.Size([10]))
 
@@ -100,7 +105,10 @@
         self.assertEqual(st_wi.index, MetadataIndex("st", [8]))
         self.assertEqual(st_wi.type, WriteItemType.SHARD)
         self.assertEqual(st_wi.tensor_data.size, st.size())
-        self.assertEqual(st_wi.tensor_data.properties, TensorProperties.create_from_tensor(torch.zeros(1)))
+        self.assertEqual(
+            st_wi.tensor_data.properties,
+            TensorProperties.create_from_tensor(torch.zeros(1)),
+        )
         self.assertEqual(st_wi.tensor_data.chunk.offsets, torch.Size([8]))
         self.assertEqual(st_wi.tensor_data.chunk.sizes, torch.Size([8]))
 
@@ -111,7 +119,10 @@
         tensor_wi = next(wi for wi in plan.items if wi.type == WriteItemType.TENSOR)
         self.assertEqual(tensor_wi.index, MetadataIndex("tensor", [0]))
         self.assertEqual(tensor_wi.tensor_data.size, tensor.size())
-        self.assertEqual(tensor_wi.tensor_data.properties, TensorProperties.create_from_tensor(tensor))
+        self.assertEqual(
+            tensor_wi.tensor_data.properties,
+            TensorProperties.create_from_tensor(tensor),
+        )
         self.assertEqual(tensor_wi.tensor_data.chunk.offsets, torch.Size([0]))
         self.assertEqual(tensor_wi.tensor_data.chunk.sizes, torch.Size([10]))
 
@@ -125,11 +136,7 @@
                 tensor = torch.rand(10)
                 val = [1, 2, 3]
                 st = create_sharded_tensor(rank=rank, world_size=4, shards_per_rank=1)
-                state_dict = {
-                    "tensor": tensor,
-                    "value": val,
-                    "st": st
-                }
+                state_dict = {"tensor": tensor, "value": val, "st": st}
                 return create_default_local_save_plan(state_dict, rank == 0)
 
         all_plans = [create_data(0), create_data(1), create_data(2), create_data(3)]
@@ -150,11 +157,15 @@
                 else:
                     self.assertTrue(isinstance(item_md, TensorStorageMetadata))
                     self.assertEqual(item_md.size, old_item.tensor_data.size)
-                    self.assertEqual(item_md.properties, old_item.tensor_data.properties)
+                    self.assertEqual(
+                        item_md.properties, old_item.tensor_data.properties
+                    )
 
                     self.assertIsNotNone(new_item.index.index)
                     # Make sure the hint is correct
-                    self.assertEqual(item_md.chunks[new_item.index.index], old_item.tensor_data.chunk)
+                    self.assertEqual(
+                        item_md.chunks[new_item.index.index], old_item.tensor_data.chunk
+                    )
 
     def test_local_load_plan(self):
         def create_state_dict(rank):
@@ -162,11 +173,7 @@
                 tensor = torch.rand(10)
                 val = [1, 2, 3]
                 st = create_sharded_tensor(rank=rank, world_size=4, shards_per_rank=1)
-                return {
-                    "tensor": tensor,
-                    "value": val,
-                    "st": st
-                }
+                return {"tensor": tensor, "value": val, "st": st}
 
         state_dict = create_state_dict(1)
         metadata = _create_default_local_metadata(state_dict)
@@ -175,7 +182,9 @@
         # This will create 3 entries
         self.assertEqual(3, len(load_plan.items))
         st_item = next(ri for ri in load_plan.items if ri.dest_index.fqn == "st")
-        tensor_item = next(ri for ri in load_plan.items if ri.dest_index.fqn == "tensor")
+        tensor_item = next(
+            ri for ri in load_plan.items if ri.dest_index.fqn == "tensor"
+        )
         bytes_item = next(ri for ri in load_plan.items if ri.dest_index.fqn == "value")
 
         self.assertEqual(st_item.type, LoadItemType.TENSOR)
@@ -208,7 +217,6 @@
                     )
                 }
 
-
         # Rank 1 has a 16 bytes shard from [16, 32[
         world8_state_dict = create_state_dict(rank=1, world_size=8)
         world8_metadata = _create_default_local_metadata(world8_state_dict)
@@ -221,8 +229,12 @@
         # Each 4-world shard has 32 elements, so it needs to load 2 shards
         load_plan = create_default_local_load_plan(world4_state_dict, world8_metadata)
         self.assertEqual(2, len(load_plan.items))
-        low_ri = next(ri for ri in load_plan.items if ri.dest_offsets == torch.Size([0]))
-        high_ri = next(ri for ri in load_plan.items if ri.dest_offsets == torch.Size([16]))
+        low_ri = next(
+            ri for ri in load_plan.items if ri.dest_offsets == torch.Size([0])
+        )
+        high_ri = next(
+            ri for ri in load_plan.items if ri.dest_offsets == torch.Size([16])
+        )
 
         self.assertEqual(low_ri.storage_index, MetadataIndex("st", [32]))
         self.assertEqual(low_ri.storage_offsets, torch.Size([0]))
@@ -258,6 +270,7 @@
                         shard_size=120 // world_size,
                     )
                 }
+
         # rank 1 has a 30 bytes shard from [30, 60[
         world4_state_dict = create_state_dict(rank=1, world_size=4)
         world4_metadata = _create_default_local_metadata(world4_state_dict)
@@ -268,9 +281,13 @@
         load_plan = create_default_local_load_plan(world3_state_dict, world4_metadata)
         self.assertEqual(2, len(load_plan.items))
         # this is [30, 60] to load [40, 60]
-        low_ri = next(ri for ri in load_plan.items if ri.dest_offsets == torch.Size([0]))
+        low_ri = next(
+            ri for ri in load_plan.items if ri.dest_offsets == torch.Size([0])
+        )
         # this is [60, 90] to load [60, 80]
-        high_ri = next(ri for ri in load_plan.items if ri.dest_offsets == torch.Size([20]))
+        high_ri = next(
+            ri for ri in load_plan.items if ri.dest_offsets == torch.Size([20])
+        )
 
         self.assertEqual(low_ri.storage_index, MetadataIndex("st", [30]))
         self.assertEqual(low_ri.storage_offsets, torch.Size([10]))
@@ -284,27 +301,19 @@
         self.assertEqual(high_ri.dest_offsets, torch.Size([20]))
         self.assertEqual(high_ri.lengths, torch.Size([20]))
 
+
 class TestPlannerHelpers(TestCase):
     def test_create_read_item_from_chunks(self):
         tensor_md = TensorStorageMetadata(
             properties=TensorProperties.create_from_tensor(torch.empty([16])),
             size=torch.Size([16]),
             chunks=[
-                ChunkStorageMetadata(
-                    offsets=torch.Size([0]),
-                    sizes=torch.Size([8])
-                ),
-                ChunkStorageMetadata(
-                    offsets=torch.Size([8]),
-                    sizes=torch.Size([8])
-                )
-            ]
+                ChunkStorageMetadata(offsets=torch.Size([0]), sizes=torch.Size([8])),
+                ChunkStorageMetadata(offsets=torch.Size([8]), sizes=torch.Size([8])),
+            ],
         )
 
-        chunk = ChunkStorageMetadata(
-            offsets=torch.Size([4]),
-            sizes=torch.Size([7])
-        )
+        chunk = ChunkStorageMetadata(offsets=torch.Size([4]), sizes=torch.Size([7]))
         read_items = create_read_items_for_chunk_list("foo", tensor_md, [chunk])
 
         self.assertEqual(2, len(read_items))
@@ -316,7 +325,6 @@
 
         self.assertEqual(torch.Size([4]), read_items[0].lengths)
 
-
         self.assertEqual(MetadataIndex("foo", [4]), read_items[1].dest_index)
         self.assertEqual(torch.Size([4]), read_items[1].dest_offsets)
 
@@ -325,5 +333,6 @@
 
         self.assertEqual(torch.Size([3]), read_items[1].lengths)
 
+
 if __name__ == "__main__":
     run_tests()
diff --git a/test/distributed/checkpoint/test_traverse.py b/test/distributed/checkpoint/test_traverse.py
index 3a47311..4755967 100644
--- a/test/distributed/checkpoint/test_traverse.py
+++ b/test/distributed/checkpoint/test_traverse.py
@@ -1,6 +1,7 @@
 # Owner(s): ["oncall: distributed"]
 
 from collections import OrderedDict
+
 import torch
 
 import torch.distributed.checkpoint._traverse as _traverse
@@ -95,9 +96,7 @@
         self.assertEqual(data[("key0", "key2")], torch.tensor([1]))
 
     def test_traverse_doesnt_ignore_intermediate_collections(self) -> None:
-        state_dict: STATE_DICT_TYPE = {
-            "key0": [{"key1": {"key2": torch.tensor([1])}}]
-        }
+        state_dict: STATE_DICT_TYPE = {"key0": [{"key1": {"key2": torch.tensor([1])}}]}
 
         data = {}
 
diff --git a/test/distributed/checkpoint/test_utils.py b/test/distributed/checkpoint/test_utils.py
index e2b4aac..78d97f0 100644
--- a/test/distributed/checkpoint/test_utils.py
+++ b/test/distributed/checkpoint/test_utils.py
@@ -6,22 +6,20 @@
 
 from torch.distributed._shard.sharded_tensor import (
     Shard,
-    ShardMetadata,
     ShardedTensor,
     ShardedTensorMetadata,
+    ShardMetadata,
 )
 from torch.distributed._shard.sharded_tensor.metadata import TensorProperties
+from torch.distributed.checkpoint.metadata import MetadataIndex
+from torch.distributed.checkpoint.utils import find_state_dict_object
 
 from torch.testing._internal.common_utils import (
-    TestCase,
-    TEST_WITH_DEV_DBG_ASAN,
     run_tests,
+    TEST_WITH_DEV_DBG_ASAN,
+    TestCase,
 )
-from torch.distributed.checkpoint.utils import find_state_dict_object
-from torch.distributed.checkpoint.metadata import MetadataIndex
-from torch.testing._internal.distributed.distributed_utils import (
-    with_fake_comms
-)
+from torch.testing._internal.distributed.distributed_utils import with_fake_comms
 
 if TEST_WITH_DEV_DBG_ASAN:
     print(
@@ -30,30 +28,32 @@
     )
     sys.exit(0)
 
+
 def create_sharded_tensor(rank, world_size, shards_per_rank):
     shards_metadata = []
     local_shards = []
     for idx in range(0, world_size * shards_per_rank):
         shard_rank = idx // shards_per_rank
-        shard_md = ShardMetadata(shard_offsets=[idx * 8], shard_sizes=[8], placement=f"rank:{shard_rank}/cpu")
+        shard_md = ShardMetadata(
+            shard_offsets=[idx * 8], shard_sizes=[8], placement=f"rank:{shard_rank}/cpu"
+        )
         shards_metadata.append(shard_md)
         if shard_rank == rank:
             shard = Shard.from_tensor_and_offsets(
                 torch.rand(*shard_md.shard_sizes),
                 shard_offsets=shard_md.shard_offsets,
-                rank=rank
+                rank=rank,
             )
             local_shards.append(shard)
 
     sharded_tensor_md = ShardedTensorMetadata(
         shards_metadata=shards_metadata,
         size=torch.Size([8 * len(shards_metadata)]),
-        tensor_properties=TensorProperties.create_from_tensor(torch.zeros(1))
+        tensor_properties=TensorProperties.create_from_tensor(torch.zeros(1)),
     )
 
     return ShardedTensor._init_from_local_shards_and_global_metadata(
-        local_shards=local_shards,
-        sharded_tensor_metadata=sharded_tensor_md
+        local_shards=local_shards, sharded_tensor_metadata=sharded_tensor_md
     )
 
 
diff --git a/torch/distributed/checkpoint/__init__.py b/torch/distributed/checkpoint/__init__.py
index c68747a..b45d0c2 100644
--- a/torch/distributed/checkpoint/__init__.py
+++ b/torch/distributed/checkpoint/__init__.py
@@ -1,23 +1,15 @@
+from .api import CheckpointException
+from .checkpointer import Checkpointer
+from .default_planner import DefaultLoadPlanner, DefaultSavePlanner
+from .filesystem import FileSystemCheckpointer, FileSystemReader, FileSystemWriter
 from .metadata import (
-    TensorStorageMetadata,
     BytesStorageMetadata,
     ChunkStorageMetadata,
     Metadata,
+    TensorStorageMetadata,
 )
-from .state_dict_loader import load_state_dict, load
-from .state_dict_saver import save_state_dict, save
-from .storage import StorageReader, StorageWriter
-from .checkpointer import Checkpointer
-from .filesystem import FileSystemReader, FileSystemWriter, FileSystemCheckpointer
-from .api import CheckpointException
-
-from .planner import (
-    SavePlanner,
-    LoadPlanner,
-    SavePlan,
-    LoadPlan,
-    ReadItem,
-    WriteItem,
-)
-from .default_planner import DefaultSavePlanner, DefaultLoadPlanner
 from .optimizer import load_sharded_optimizer_state_dict
+from .planner import LoadPlan, LoadPlanner, ReadItem, SavePlan, SavePlanner, WriteItem
+from .state_dict_loader import load, load_state_dict
+from .state_dict_saver import save, save_state_dict
+from .storage import StorageReader, StorageWriter
diff --git a/torch/distributed/checkpoint/_dedup_tensors.py b/torch/distributed/checkpoint/_dedup_tensors.py
index cf7ca01..a6be92c 100644
--- a/torch/distributed/checkpoint/_dedup_tensors.py
+++ b/torch/distributed/checkpoint/_dedup_tensors.py
@@ -23,8 +23,10 @@
     logger.propagate = False
     return logger
 
+
 logger = init_logger()
 
+
 # TODO add docstring for dedup_tensors
 def dedup_tensors(all_plans: List[SavePlan]) -> List[SavePlan]:
     all_plans = list(all_plans)
@@ -51,8 +53,6 @@
             for write_item in all_plans[plan_idx].items
             if write_item.index not in key_set
         ]
-        all_plans[plan_idx] = dataclasses.replace(
-            all_plans[plan_idx], items=new_items
-        )
+        all_plans[plan_idx] = dataclasses.replace(all_plans[plan_idx], items=new_items)
 
     return all_plans
diff --git a/torch/distributed/checkpoint/_fsspec_filesystem.py b/torch/distributed/checkpoint/_fsspec_filesystem.py
index 8703d1b..5c98845 100644
--- a/torch/distributed/checkpoint/_fsspec_filesystem.py
+++ b/torch/distributed/checkpoint/_fsspec_filesystem.py
@@ -9,20 +9,18 @@
 import queue
 import threading
 from abc import ABC, abstractmethod
-
 from dataclasses import dataclass
 from typing import Callable, cast, Dict, List, Optional, Union
 
 import fsspec
-import torch
 from fsspec import AbstractFileSystem
 from fsspec.core import url_to_fs
+
+import torch
 from torch import Tensor
 from torch._utils import _get_device_module
-
 from torch.distributed._shard._utils import narrow_tensor_by_index
 from torch.distributed.checkpoint.metadata import Metadata, MetadataIndex
-
 from torch.distributed.checkpoint.planner import (
     LoadItemType,
     LoadPlan,
@@ -145,10 +143,7 @@
 
     def _refill(self):
         with self.device_module.stream(self.stream):
-            while (
-                not self._done
-                and self.in_flight_data < self.inflight_threshhold
-            ):
+            while not self._done and self.in_flight_data < self.inflight_threshhold:
                 _, obj = self.items[self.idx]
                 self.idx += 1
                 tensor = self.resolve_fun(obj).detach()
@@ -206,9 +201,7 @@
     return size * torch._utils._element_size(dtype)
 
 
-def _split_by_size_and_type(
-    bins: int, items: List[WriteItem]
-) -> List[List[WriteItem]]:
+def _split_by_size_and_type(bins: int, items: List[WriteItem]) -> List[List[WriteItem]]:
     if bins == 1:
         return [items]
 
@@ -276,16 +269,12 @@
                     planner.resolve_data,
                 )
 
-            tensor_w = [
-                wi for wi in write_items if wi.type != WriteItemType.BYTE_IO
-            ]
+            tensor_w = [wi for wi in write_items if wi.type != WriteItemType.BYTE_IO]
             for write_item in tensor_w:
                 loader.add(_item_size(write_item), write_item)
             loader.start_loading()
 
-            bytes_w = [
-                wi for wi in write_items if wi.type == WriteItemType.BYTE_IO
-            ]
+            bytes_w = [wi for wi in write_items if wi.type == WriteItemType.BYTE_IO]
             write_results = []
 
             with fs.transaction:
@@ -351,9 +340,7 @@
         self.fs.makedirs(self.path, exist_ok=True)
         return plan
 
-    def prepare_global_plan(
-        self, global_plan: List[SavePlan]
-    ) -> List[SavePlan]:
+    def prepare_global_plan(self, global_plan: List[SavePlan]) -> List[SavePlan]:
         new_plans = [
             dataclasses.replace(plan, storage_data=_StoragePrefix(f"__{i}_"))
             for i, plan in enumerate(global_plan)
@@ -376,9 +363,7 @@
 
         file_queue: queue.Queue = queue.Queue()
         if self.single_file_per_rank:
-            for bucket in _split_by_size_and_type(
-                self.thread_count, plan.items
-            ):
+            for bucket in _split_by_size_and_type(self.thread_count, plan.items):
                 file_name = gen_file()
                 file_path = os.path.join(self.path, file_name)
                 file_queue.put((file_path, file_name, bucket))
@@ -427,9 +412,7 @@
             fut.set_result(res)
             return fut
 
-    def finish(
-        self, metadata: Metadata, results: List[List[WriteResult]]
-    ) -> None:
+    def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
         storage_md = dict()
         for wr_list in results:
             storage_md.update({wr.index: wr.storage_data for wr in wr_list})
@@ -495,16 +478,12 @@
         with fsspec.open(metadata_path, "rb") as metadata_file:
             return pickle.load(metadata_file)
 
-    def set_up_storage_reader(
-        self, metadata: Metadata, is_coordinator: bool
-    ) -> None:
+    def set_up_storage_reader(self, metadata: Metadata, is_coordinator: bool) -> None:
         self.storage_data = metadata.storage_data
         assert self.storage_data is not None
 
     def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
         return plan
 
-    def prepare_global_plan(
-        self, global_plan: List[LoadPlan]
-    ) -> List[LoadPlan]:
+    def prepare_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]:
         return global_plan
diff --git a/torch/distributed/checkpoint/_nested_dict.py b/torch/distributed/checkpoint/_nested_dict.py
index 6d2e0bc..527a67e 100644
--- a/torch/distributed/checkpoint/_nested_dict.py
+++ b/torch/distributed/checkpoint/_nested_dict.py
@@ -1,16 +1,9 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
 from typing import Dict, Tuple
 
-from torch.distributed.checkpoint.metadata import (
-    STATE_DICT_TYPE,
-)
+from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
 
-from ._traverse import (
-    traverse_state_dict,
-    set_element,
-    OBJ_PATH,
-    STATE_DICT_ITEM,
-)
+from ._traverse import OBJ_PATH, set_element, STATE_DICT_ITEM, traverse_state_dict
 
 """
 TODO:
diff --git a/torch/distributed/checkpoint/_sharded_tensor_utils.py b/torch/distributed/checkpoint/_sharded_tensor_utils.py
index 4fff382..582dfc0 100644
--- a/torch/distributed/checkpoint/_sharded_tensor_utils.py
+++ b/torch/distributed/checkpoint/_sharded_tensor_utils.py
@@ -3,29 +3,12 @@
 import copy
 
 import torch.distributed as dist
+from torch.distributed._shard.sharded_tensor import Shard, ShardedTensor, ShardMetadata
+from torch.distributed._shard.sharded_tensor.metadata import ShardedTensorMetadata
+from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
 from torch.distributed.remote_device import _remote_device
 
-from torch.distributed.checkpoint.metadata import (
-    STATE_DICT_TYPE,
-)
-from torch.distributed._shard.sharded_tensor import (
-    Shard,
-    ShardMetadata,
-    ShardedTensor,
-)
-
-from torch.distributed._shard.sharded_tensor.metadata import (
-    ShardedTensorMetadata,
-)
-
-
-from ._traverse import (
-    OBJ_PATH,
-    traverse_state_dict,
-    set_element,
-    STATE_DICT_ITEM,
-)
-
+from ._traverse import OBJ_PATH, set_element, STATE_DICT_ITEM, traverse_state_dict
 from .utils import _element_wise_add, _normalize_device_info
 
 
@@ -62,9 +45,7 @@
             return
 
         if len(inner_st.local_shards()) != 1:
-            raise ValueError(
-                "Cannot handle inner tensor with more than 1 shard"
-            )
+            raise ValueError("Cannot handle inner tensor with more than 1 shard")
         inner_shard = inner_st.local_shards()[0]
 
         local_shards = [
diff --git a/torch/distributed/checkpoint/_traverse.py b/torch/distributed/checkpoint/_traverse.py
index 1367dc2..604b5e1 100644
--- a/torch/distributed/checkpoint/_traverse.py
+++ b/torch/distributed/checkpoint/_traverse.py
@@ -1,8 +1,7 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
-import torch
-
 from typing import (
     Callable,
+    cast,
     Collection,
     List,
     Mapping,
@@ -11,13 +10,12 @@
     Tuple,
     TypeVar,
     Union,
-    cast,
 )
-from torch.distributed.checkpoint.metadata import (
-    STATE_DICT_TYPE,
-)
+
+import torch
 from torch.distributed._shard.sharded_tensor.api import ShardedTensor
 from torch.distributed._tensor import DTensor
+from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
 
 PATH_ITEM = Union[str, int]
 OBJ_PATH = Tuple[PATH_ITEM, ...]
@@ -47,6 +45,7 @@
     By default, all collections with at least one ``torch.Tensor`` element are traversed.
     Visitor takes a path argument that is a tuple of the keys used to reach it.
     """
+
     # a value is terminal if it has no other containers values inside it
     def _is_terminal(value: STATE_DICT_ITEM) -> bool:
         values: Collection[STATE_DICT_ITEM]
diff --git a/torch/distributed/checkpoint/api.py b/torch/distributed/checkpoint/api.py
index 8ccd740..8286851 100644
--- a/torch/distributed/checkpoint/api.py
+++ b/torch/distributed/checkpoint/api.py
@@ -1,5 +1,5 @@
-from typing import Dict, Tuple, Any
 import traceback as tb
+from typing import Any, Dict, Tuple
 
 WRAPPED_EXCEPTION = Tuple[BaseException, tb.StackSummary]
 
@@ -15,9 +15,7 @@
         return False
     if len(obj) != 2:
         return False
-    return isinstance(obj[0], BaseException) and isinstance(
-        obj[1], tb.StackSummary
-    )
+    return isinstance(obj[0], BaseException) and isinstance(obj[1], tb.StackSummary)
 
 
 class CheckpointException(BaseException):
diff --git a/torch/distributed/checkpoint/default_planner.py b/torch/distributed/checkpoint/default_planner.py
index 9bf74ab..775d5b8 100644
--- a/torch/distributed/checkpoint/default_planner.py
+++ b/torch/distributed/checkpoint/default_planner.py
@@ -6,50 +6,42 @@
 import operator
 from collections import ChainMap
 from functools import reduce
-from typing import List, Tuple, Dict, Any, Union, cast
+from typing import Any, cast, Dict, List, Tuple, Union
 
 import torch
-
 from torch.distributed._shard._utils import narrow_tensor_by_index
 from torch.distributed._tensor import DTensor
-
-
-from torch.distributed.checkpoint.planner import (
-    SavePlanner,
-    LoadPlanner,
-    SavePlan,
-    LoadPlan,
-    ReadItem,
-    WriteItem,
-    WriteItemType,
-)
-
-from torch.distributed.checkpoint.metadata import (
-    BytesStorageMetadata,
-    ChunkStorageMetadata,
-    TensorStorageMetadata,
-    MetadataIndex,
-    Metadata,
-    STATE_DICT_TYPE,
-    STORAGE_TYPES,
-)
-
-from torch.distributed.checkpoint.planner_helpers import (
-    _create_read_items,
-    _create_write_items,
-    _create_default_metadata_only_plan,
-)
-
+from torch.distributed.checkpoint._dedup_tensors import dedup_tensors
 from torch.distributed.checkpoint._nested_dict import (
     FLATTEN_MAPPING,
     flatten_state_dict,
 )
-from torch.distributed.checkpoint._sharded_tensor_utils import (
-    _flatten_sharded_tensors,
-)
-from torch.distributed.checkpoint._dedup_tensors import dedup_tensors
-from torch.distributed.checkpoint.utils import find_state_dict_object
+from torch.distributed.checkpoint._sharded_tensor_utils import _flatten_sharded_tensors
 from torch.distributed.checkpoint._traverse import set_element
+from torch.distributed.checkpoint.metadata import (
+    BytesStorageMetadata,
+    ChunkStorageMetadata,
+    Metadata,
+    MetadataIndex,
+    STATE_DICT_TYPE,
+    STORAGE_TYPES,
+    TensorStorageMetadata,
+)
+from torch.distributed.checkpoint.planner import (
+    LoadPlan,
+    LoadPlanner,
+    ReadItem,
+    SavePlan,
+    SavePlanner,
+    WriteItem,
+    WriteItemType,
+)
+from torch.distributed.checkpoint.planner_helpers import (
+    _create_default_metadata_only_plan,
+    _create_read_items,
+    _create_write_items,
+)
+from torch.distributed.checkpoint.utils import find_state_dict_object
 
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -79,9 +71,7 @@
         self.dedup_replicated_tensors = dedup_replicated_tensors
         self.mappings = {}
 
-    def set_up_planner(
-        self, state_dict: STATE_DICT_TYPE, is_coordinator: bool
-    ) -> None:
+    def set_up_planner(self, state_dict: STATE_DICT_TYPE, is_coordinator: bool) -> None:
         if self.flatten_state_dict:
             state_dict, self.mappings = flatten_state_dict(state_dict)
         if self.flatten_sharded_tensors:
@@ -90,9 +80,7 @@
         self.is_coordinator = is_coordinator
 
     def create_local_plan(self) -> SavePlan:
-        plan = create_default_local_save_plan(
-            self.state_dict, self.is_coordinator
-        )
+        plan = create_default_local_save_plan(self.state_dict, self.is_coordinator)
         if self.flatten_state_dict:
             plan = dataclasses.replace(plan, planner_data=self.mappings)
         self.plan = plan
@@ -114,9 +102,7 @@
             # )
             planner_data_dict = [p.planner_data for p in global_plan]
             merged_mappings = dict(ChainMap(*planner_data_dict))
-            metadata = dataclasses.replace(
-                metadata, planner_data=merged_mappings
-            )
+            metadata = dataclasses.replace(metadata, planner_data=merged_mappings)
 
         if not _validate_global_plan(global_plan, metadata):
             raise ValueError("Failed to validate global plan")
@@ -130,9 +116,7 @@
         self.plan = new_plan
         return new_plan
 
-    def resolve_data(
-        self, write_item: WriteItem
-    ) -> Union[torch.Tensor, io.BytesIO]:
+    def resolve_data(self, write_item: WriteItem) -> Union[torch.Tensor, io.BytesIO]:
         object = self.lookup_object(write_item.index)
         return self.transform_object(write_item, object)
 
@@ -222,9 +206,7 @@
 
     def transform_tensor(self, read_item: ReadItem, tensor: torch.Tensor):
         """Extension from the planner interface to make it easy to extend the default planner."""
-        return narrow_tensor_by_index(
-            tensor, read_item.dest_offsets, read_item.lengths
-        )
+        return narrow_tensor_by_index(tensor, read_item.dest_offsets, read_item.lengths)
 
 
 def create_default_local_load_plan(
@@ -353,9 +335,7 @@
     return md
 
 
-def _check_box_overlap(
-    box0: ChunkStorageMetadata, box1: ChunkStorageMetadata
-) -> bool:
+def _check_box_overlap(box0: ChunkStorageMetadata, box1: ChunkStorageMetadata) -> bool:
     """Check if two boxes overlap. Tuples are (offset, lengths)."""
     # For each dim of each shard, check if one shard resides on the other
     # end of second shard with respect to that dim. As an example for a 2D
@@ -385,9 +365,7 @@
     return True
 
 
-def _validate_global_plan(
-    global_plan: List[SavePlan], metadata: Metadata
-) -> bool:
+def _validate_global_plan(global_plan: List[SavePlan], metadata: Metadata) -> bool:
     all_good = True
     for key, value in metadata.state_dict_metadata.items():
         if isinstance(value, BytesStorageMetadata):
@@ -402,7 +380,10 @@
                     """
                         key:%s has out of bounds chunk:
                         tensor-size:%s chunk: %s
-                    """, key, value.size, chunk0
+                    """,
+                    key,
+                    value.size,
+                    chunk0,
                 )
                 all_good = False
             chunks_volume += reduce(operator.mul, chunk0.sizes, 1)
@@ -422,7 +403,10 @@
                 """
                     key:%s invalid fill tensor-volume:
                     %s chunks-volume: %s
-                """, key, tensor_volume, chunks_volume
+                """,
+                key,
+                tensor_volume,
+                chunks_volume,
             )
             all_good = False
 
diff --git a/torch/distributed/checkpoint/examples/fsdp_checkpoint_example.py b/torch/distributed/checkpoint/examples/fsdp_checkpoint_example.py
index 347f979..9e2438c 100644
--- a/torch/distributed/checkpoint/examples/fsdp_checkpoint_example.py
+++ b/torch/distributed/checkpoint/examples/fsdp_checkpoint_example.py
@@ -14,12 +14,10 @@
 import torch.distributed as dist
 import torch.distributed.checkpoint as dist_cp
 import torch.multiprocessing as mp
+from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
 
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType
-from torch.distributed.checkpoint.optimizer import (
-    load_sharded_optimizer_state_dict,
-)
 
 CHECKPOINT_DIR = f"/scratch/{os.environ['LOGNAME']}/checkpoint"
 
diff --git a/torch/distributed/checkpoint/filesystem.py b/torch/distributed/checkpoint/filesystem.py
index 6f7c710..8ee2a67 100644
--- a/torch/distributed/checkpoint/filesystem.py
+++ b/torch/distributed/checkpoint/filesystem.py
@@ -1,53 +1,38 @@
-from abc import ABC, abstractmethod
-import queue
-import threading
 import collections
-
-from dataclasses import dataclass
-import os
 import dataclasses
 import io
+import os
 import pickle
-from typing import Optional, List, Union, Dict, cast
+import queue
+import threading
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from pathlib import Path
+from typing import cast, Dict, List, Optional, Union
 
 import torch
+import torch.distributed as dist
 from torch import Tensor
+from torch._utils import _get_device_module
+from torch.distributed._shard._utils import narrow_tensor_by_index
+from torch.distributed.checkpoint.checkpointer import Checkpointer
 from torch.futures import Future
-from pathlib import Path
 
-from .metadata import (
-    Metadata,
-    MetadataIndex,
-)
-from .storage import (
-    StorageReader,
-    StorageWriter,
-    WriteResult,
-)
-
+from .metadata import Metadata, MetadataIndex
 from .planner import (
     LoadItemType,
-    LoadPlanner,
     LoadPlan,
+    LoadPlanner,
+    ReadItem,
     SavePlan,
     SavePlanner,
-    ReadItem,
     WriteItem,
     WriteItemType,
 )
-
+from .storage import StorageReader, StorageWriter, WriteResult
 from .utils import _create_file_view
 
-import torch.distributed as dist
-from torch.distributed.checkpoint.checkpointer import Checkpointer
-from torch.distributed._shard._utils import narrow_tensor_by_index
-from torch._utils import _get_device_module
-
-__all__ = [
-    "FileSystemWriter",
-    "FileSystemReader",
-    "FileSystemCheckpointer"
-]
+__all__ = ["FileSystemWriter", "FileSystemReader", "FileSystemCheckpointer"]
 
 
 @dataclass
@@ -150,10 +135,7 @@
 
     def _refill(self):
         with self.device_module.stream(self.stream):
-            while (
-                not self._done
-                and self.in_flight_data < self.inflight_threshhold
-            ):
+            while not self._done and self.in_flight_data < self.inflight_threshhold:
                 _, obj = self.items[self.idx]
                 self.idx += 1
                 tensor = self.resolve_fun(obj).detach()
@@ -211,9 +193,7 @@
     return size * torch._utils._element_size(dtype)
 
 
-def _split_by_size_and_type(
-    bins, items: List[WriteItem]
-) -> List[List[WriteItem]]:
+def _split_by_size_and_type(bins, items: List[WriteItem]) -> List[List[WriteItem]]:
     if bins == 1:
         return [items]
 
@@ -276,16 +256,12 @@
                     planner.resolve_data,
                 )
 
-            tensor_w = [
-                wi for wi in write_items if wi.type != WriteItemType.BYTE_IO
-            ]
+            tensor_w = [wi for wi in write_items if wi.type != WriteItemType.BYTE_IO]
             for write_item in tensor_w:
                 loader.add(_item_size(write_item), write_item)
             loader.start_loading()
 
-            bytes_w = [
-                wi for wi in write_items if wi.type == WriteItemType.BYTE_IO
-            ]
+            bytes_w = [wi for wi in write_items if wi.type == WriteItemType.BYTE_IO]
             write_results = []
 
             with file_name.open("wb") as stream:
@@ -358,9 +334,7 @@
         self.path.mkdir(parents=True, exist_ok=True)
         return plan
 
-    def prepare_global_plan(
-        self, global_plan: List[SavePlan]
-    ) -> List[SavePlan]:
+    def prepare_global_plan(self, global_plan: List[SavePlan]) -> List[SavePlan]:
         new_plans = [
             dataclasses.replace(plan, storage_data=_StoragePrefix(f"__{i}_"))
             for i, plan in enumerate(global_plan)
@@ -383,9 +357,7 @@
 
         file_queue: queue.Queue = queue.Queue()
         if self.single_file_per_rank:
-            for bucket in _split_by_size_and_type(
-                self.thread_count, plan.items
-            ):
+            for bucket in _split_by_size_and_type(self.thread_count, plan.items):
                 file_name = gen_file()
                 file_queue.put((self.path / file_name, file_name, bucket))
         else:
@@ -432,9 +404,7 @@
             fut.set_result(res)
             return fut
 
-    def finish(
-        self, metadata: Metadata, results: List[List[WriteResult]]
-    ) -> None:
+    def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
         storage_md = dict()
         for wr_list in results:
             storage_md.update({wr.index: wr.storage_data for wr in wr_list})
@@ -507,11 +477,10 @@
     def prepare_local_plan(self, plan: LoadPlan) -> LoadPlan:
         return plan
 
-    def prepare_global_plan(
-        self, global_plan: List[LoadPlan]
-    ) -> List[LoadPlan]:
+    def prepare_global_plan(self, global_plan: List[LoadPlan]) -> List[LoadPlan]:
         return global_plan
 
+
 class FileSystemCheckpointer(Checkpointer):
     """An implementation of :py:class:`torch.distributed.checkpoint.checkpointer.Checkpointer`
     for the file system. Wraps the creation and usage of ``FileSystemWriter`` and ``FileSystemReader``.
@@ -547,11 +516,7 @@
         """
 
         storage_writer = FileSystemWriter(
-            path,
-            single_file_per_rank,
-            sync_files,
-            thread_count,
-            per_thread_copy_ahead
+            path, single_file_per_rank, sync_files, thread_count, per_thread_copy_ahead
         )
         storage_reader = FileSystemReader(path)
 
@@ -562,5 +527,5 @@
             coordinator_rank=coordinator_rank,
             no_dist=no_dist,
             load_planner=load_planner,
-            save_planner=save_planner
+            save_planner=save_planner,
         )
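
Beyond the import reordering and call-site reflows, `filesystem.py` is the storage layer most callers touch directly. A minimal round-trip sketch under assumed conditions: the script is launched with torchrun (a single rank is enough), gloo is available, and `/tmp/dcp_demo` is an illustrative writable path (shared storage for multi-node jobs):

```python
import torch
import torch.distributed as dist
from torch.distributed.checkpoint.filesystem import FileSystemReader, FileSystemWriter
from torch.distributed.checkpoint.state_dict_loader import load
from torch.distributed.checkpoint.state_dict_saver import save

dist.init_process_group("gloo")
CKPT_DIR = "/tmp/dcp_demo"  # illustrative path

# A deterministic tensor keeps the example valid at any world size.
state_dict = {"weight": torch.arange(16.0).reshape(4, 4)}
save(state_dict=state_dict, storage_writer=FileSystemWriter(CKPT_DIR))

# load() restores in place: pre-allocate tensors of the right shape first.
restored = {"weight": torch.zeros(4, 4)}
load(state_dict=restored, storage_reader=FileSystemReader(CKPT_DIR))
assert torch.equal(restored["weight"], state_dict["weight"])

dist.destroy_process_group()
```

The new `FileSystemCheckpointer` wraps exactly this writer/reader pair behind the `Checkpointer` interface, so the same on-disk layout is produced either way.
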
diff --git a/torch/distributed/checkpoint/metadata.py b/torch/distributed/checkpoint/metadata.py
index 8477e46..4ce6250 100644
--- a/torch/distributed/checkpoint/metadata.py
+++ b/torch/distributed/checkpoint/metadata.py
@@ -1,12 +1,10 @@
 from dataclasses import dataclass, field
-from typing import Dict, List, Union, Optional, Sequence, Any
-from torch.distributed._shard.sharded_tensor.metadata import TensorProperties
-from torch.distributed.checkpoint.stateful import StatefulT
+from typing import Any, Dict, List, Optional, Sequence, Union
 
 import torch
-from torch.distributed._shard.sharded_tensor import (
-    ShardedTensor,
-)
+from torch.distributed._shard.sharded_tensor import ShardedTensor
+from torch.distributed._shard.sharded_tensor.metadata import TensorProperties
+from torch.distributed.checkpoint.stateful import StatefulT
 
 __all__ = [
     "ChunkStorageMetadata",
diff --git a/torch/distributed/checkpoint/optimizer.py b/torch/distributed/checkpoint/optimizer.py
index 5e66f1e..7dd932b 100644
--- a/torch/distributed/checkpoint/optimizer.py
+++ b/torch/distributed/checkpoint/optimizer.py
@@ -1,49 +1,41 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates
 
 import dataclasses
-from typing import Dict, List, Optional, Sequence, Tuple, Union, cast
-from torch.distributed.checkpoint.planner import LoadPlan
+from typing import cast, Dict, List, Optional, Sequence, Tuple, Union
 
 import torch
 import torch.distributed as dist
+from torch._utils import _get_device_module
 from torch.distributed._shard.sharded_tensor.api import ShardedTensor
 from torch.distributed._shard.sharded_tensor.metadata import TensorProperties
 from torch.distributed._shard.sharded_tensor.shard import Shard
-from torch.distributed._shard.sharding_spec.chunk_sharding_spec import (
-    ChunkShardingSpec,
-)
-
-import torch.distributed.checkpoint as dist_cp
+from torch.distributed._shard.sharding_spec.chunk_sharding_spec import ChunkShardingSpec
+from torch.distributed._tensor import DTensor
+from torch.distributed.checkpoint._nested_dict import unflatten_state_dict
+from torch.distributed.checkpoint.default_planner import DefaultLoadPlanner
 from torch.distributed.checkpoint.metadata import (
     BytesStorageMetadata,
+    ChunkStorageMetadata,
     Metadata,
     MetadataIndex,
     STATE_DICT_TYPE,
     TensorStorageMetadata,
-    ChunkStorageMetadata,
 )
-from torch.distributed.distributed_c10d import _get_default_group
-from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor
+from torch.distributed.checkpoint.planner import LoadPlan, LoadPlanner
 from torch.distributed.checkpoint.planner_helpers import (
-    create_read_items_for_chunk_list,
     _create_read_items,
+    create_read_items_for_chunk_list,
 )
-from torch.distributed.remote_device import _remote_device
-
-from torch.distributed._tensor import DTensor
-from torch.distributed.checkpoint.default_planner import (
-    DefaultLoadPlanner,
-)
-from torch.distributed.checkpoint.planner import LoadPlanner
-
-from torch.distributed.checkpoint._nested_dict import unflatten_state_dict
+from torch.distributed.checkpoint.state_dict_loader import load_state_dict
+from torch.distributed.checkpoint.storage import StorageReader
 from torch.distributed.checkpoint.utils import (
     _element_wise_add,
     _element_wise_sub,
-    _normalize_device_info
+    _normalize_device_info,
 )
-
-from torch._utils import _get_device_module
+from torch.distributed.distributed_c10d import _get_default_group
+from torch.distributed.fsdp._shard_utils import _create_chunk_sharded_tensor
+from torch.distributed.remote_device import _remote_device
 
 STATE_DICT_2D_LAYOUT = Dict[str, Tuple[Optional[Sequence[int]], Sequence[int]]]
 
@@ -59,7 +51,9 @@
         return "cpu"
     device_module = _get_device_module(device_type)
     if device_module.is_available():
-        return _normalize_device_info(device_type, global_rank % device_module.device_count())
+        return _normalize_device_info(
+            device_type, global_rank % device_module.device_count()
+        )
     return "cpu"
 
 
@@ -90,18 +84,17 @@
         if type(val.local_shards()[0].tensor) is ShardedTensor:
             return True
         if type(val.local_shards()[0].tensor) is DTensor:
-            raise ValueError(
-                "Cannot handle DTensor nested insided ShardedTensor"
-            )
+            raise ValueError("Cannot handle DTensor nested inside ShardedTensor")
     elif type(val) is DTensor and (
-        type(val._local_tensor) is DTensor
-        or type(val._local_tensor) is ShardedTensor
+        type(val._local_tensor) is DTensor or type(val._local_tensor) is ShardedTensor
     ):
         raise ValueError("Cannot handle nested DTensor")
     return False
 
 
-def _alloc_tensor(props: TensorProperties, size: Sequence[int], device_type: str = "cuda") -> torch.Tensor:
+def _alloc_tensor(
+    props: TensorProperties, size: Sequence[int], device_type: str = "cuda"
+) -> torch.Tensor:
     return torch.empty(
         size=size,
         dtype=props.dtype,
@@ -181,9 +174,7 @@
             local_chunks = [
                 ChunkStorageMetadata(
                     offsets=torch.Size(
-                        _element_wise_add(
-                            original_shard.metadata.shard_offsets, offset
-                        )
+                        _element_wise_add(original_shard.metadata.shard_offsets, offset)
                     ),
                     sizes=torch.Size(original_shard.metadata.shard_sizes),
                 )
@@ -196,9 +187,7 @@
             # TODO: we should change _create_sharded_read_items to have more ergonomic API
             for ri in reqs:
                 assert ri.dest_index.offset is not None
-                original_offset = _element_wise_sub(
-                    ri.dest_index.offset, offset
-                )
+                original_offset = _element_wise_sub(ri.dest_index.offset, offset)
                 original_index = dataclasses.replace(
                     ri.dest_index, offset=torch.Size(original_offset)
                 )
@@ -214,7 +203,7 @@
 def load_sharded_optimizer_state_dict(
     model_state_dict: STATE_DICT_TYPE,
     optimizer_key: str,
-    storage_reader: dist_cp.StorageReader,
+    storage_reader: StorageReader,
     planner: Optional[LoadPlanner] = None,
 ) -> STATE_DICT_TYPE:
     """
@@ -273,7 +262,9 @@
     if dp_pg is None:
         placements = []
         for i in range(dist.get_world_size()):
-            device_info = _normalize_device_info(dp_pg_device_type, i % device_module.device_count())
+            device_info = _normalize_device_info(
+                dp_pg_device_type, i % device_module.device_count()
+            )
             placements.append(f"rank:{i}/{device_info}")
         sharding_spec = ChunkShardingSpec(dim=0, placements=placements)  # type: ignore[arg-type]
     else:
@@ -294,7 +285,9 @@
 
         # value: TensorStorageMetadata
         if value.size.numel() == 1:
-            state_dict[key] = _alloc_tensor(value.properties, value.size, dp_pg_device_type)
+            state_dict[key] = _alloc_tensor(
+                value.properties, value.size, dp_pg_device_type
+            )
         elif dp_pg is None:
             state_dict[key] = _create_chunk_sharded_tensor(
                 _alloc_tensor(value.properties, value.size, dp_pg_device_type),
@@ -313,10 +306,7 @@
             local_shards = []
             current_rank = dist.get_rank(dp_pg)
             for shard_md in st_md.shards_metadata:
-                if (
-                    cast(_remote_device, shard_md.placement).rank()
-                    != current_rank
-                ):
+                if cast(_remote_device, shard_md.placement).rank() != current_rank:
                     continue
                 local_shards.append(
                     Shard(
@@ -331,18 +321,13 @@
                 local_shards, st_md, process_group=dp_pg
             )
 
-            if (
-                spec_key in layout_specs
-                and layout_specs[spec_key][0] is not None
-            ):
-                fqn_to_offset[key] = cast(
-                    Sequence[int], layout_specs[spec_key][0]
-                )
+            if spec_key in layout_specs and layout_specs[spec_key][0] is not None:
+                fqn_to_offset[key] = cast(Sequence[int], layout_specs[spec_key][0])
 
             state_dict[key] = st
 
     # Whether we unflatten before or after doesn't matter
-    dist_cp.load_state_dict(
+    load_state_dict(
         state_dict=state_dict,
         storage_reader=storage_reader,
         # FIXME the type of planner is wrong in load_state_dict
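
`load_sharded_optimizer_state_dict` is the public entry point of this file: it rebuilds a sharded optimizer state dict straight from checkpoint metadata, so FSDP optimizers can be restored without materializing full tensors on any rank. A hedged sketch of the usual restore flow, assuming an already-initialized FSDP job in which `model` is the FSDP-wrapped module, `optim` its optimizer, and `CHECKPOINT_DIR` the directory written at save time (all three are placeholders, not defined here):

```python
from torch.distributed.checkpoint import FileSystemReader
from torch.distributed.checkpoint.optimizer import load_sharded_optimizer_state_dict
from torch.distributed.checkpoint.state_dict_loader import load
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
from torch.distributed.fsdp.fully_sharded_data_parallel import StateDictType

with FSDP.state_dict_type(model, StateDictType.SHARDED_STATE_DICT):
    # Restore the model first; its sharded state_dict tells the optimizer
    # loader how the optimizer tensors should be laid out.
    state_dict = {"model": model.state_dict()}
    load(state_dict=state_dict, storage_reader=FileSystemReader(CHECKPOINT_DIR))
    model.load_state_dict(state_dict["model"])

    optim_state = load_sharded_optimizer_state_dict(
        model_state_dict=state_dict["model"],
        optimizer_key="optim",  # the key the optimizer state was saved under
        storage_reader=FileSystemReader(CHECKPOINT_DIR),
    )
    # Translate DCP's layout back into what optim.load_state_dict() expects;
    # the argument order follows the current FSDP API.
    flattened = FSDP.optim_state_dict_to_load(model, optim, optim_state["optim"])
    optim.load_state_dict(flattened)
```
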
diff --git a/torch/distributed/checkpoint/planner.py b/torch/distributed/checkpoint/planner.py
index b4aed46..ebcdfb0 100644
--- a/torch/distributed/checkpoint/planner.py
+++ b/torch/distributed/checkpoint/planner.py
@@ -1,19 +1,13 @@
 import abc
-from dataclasses import dataclass
 import io
-from typing import List, Tuple, Any, Union, Optional
+from dataclasses import dataclass
+from enum import auto, Enum
+from typing import Any, List, Optional, Tuple, Union
 
-from enum import Enum, auto
 import torch
-
 from torch.distributed._shard.sharded_tensor.metadata import TensorProperties
 
-from .metadata import (
-    ChunkStorageMetadata,
-    MetadataIndex,
-    Metadata,
-    STATE_DICT_TYPE,
-)
+from .metadata import ChunkStorageMetadata, Metadata, MetadataIndex, STATE_DICT_TYPE
 
 
 __all__ = [
@@ -223,9 +217,7 @@
         pass
 
     @abc.abstractmethod
-    def resolve_data(
-        self, write_item: WriteItem
-    ) -> Union[torch.Tensor, io.BytesIO]:
+    def resolve_data(self, write_item: WriteItem) -> Union[torch.Tensor, io.BytesIO]:
         """
         Transform and prepare ``write_item`` from ``state_dict`` for storage, ensuring idempotency and thread-safety.
 
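
The `resolve_data` hunk above is part of the `SavePlanner` extension surface: a planner gets the last word on each tensor or byte stream right before the storage layer writes it. A hypothetical subclass of the stock `DefaultSavePlanner` that only observes what is about to be written (the logging is illustrative, not part of DCP):

```python
import io
import logging
from typing import Union

import torch
from torch.distributed.checkpoint.default_planner import DefaultSavePlanner
from torch.distributed.checkpoint.planner import WriteItem

logger = logging.getLogger(__name__)


class LoggingSavePlanner(DefaultSavePlanner):
    """Purely observational: delegates resolution to the default planner."""

    def resolve_data(self, write_item: WriteItem) -> Union[torch.Tensor, io.BytesIO]:
        data = super().resolve_data(write_item)
        if torch.is_tensor(data):
            logger.debug(
                "writing tensor %s with shape %s",
                write_item.index.fqn,
                tuple(data.shape),
            )
        else:
            logger.debug("writing bytes for %s", write_item.index.fqn)
        return data
```

Such a planner drops in wherever the default one is accepted, e.g. `save(..., planner=LoggingSavePlanner())`.
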
diff --git a/torch/distributed/checkpoint/planner_helpers.py b/torch/distributed/checkpoint/planner_helpers.py
index e9cde82..c01ea0d 100644
--- a/torch/distributed/checkpoint/planner_helpers.py
+++ b/torch/distributed/checkpoint/planner_helpers.py
@@ -1,7 +1,6 @@
 from typing import Any, List
 
 import torch
-
 from torch.distributed._shard.metadata import ShardMetadata
 from torch.distributed._shard.sharded_tensor import ShardedTensor
 from torch.distributed._shard.sharded_tensor.metadata import TensorProperties
@@ -16,7 +15,6 @@
     STORAGE_TYPES,
     TensorStorageMetadata,
 )
-
 from .planner import (
     LoadItemType,
     ReadItem,
@@ -25,7 +23,6 @@
     WriteItem,
     WriteItemType,
 )
-
 from .resharding import (
     _check_shard_metadata_pair_overlap,
     _shards_get_overlap_region_wrt_saved_tensor,
diff --git a/torch/distributed/checkpoint/resharding.py b/torch/distributed/checkpoint/resharding.py
index 8753bbc..1ebb0ba 100644
--- a/torch/distributed/checkpoint/resharding.py
+++ b/torch/distributed/checkpoint/resharding.py
@@ -1,12 +1,13 @@
 from typing import List, Tuple
 
-from torch.distributed.checkpoint.metadata import (
-    ChunkStorageMetadata
-)
+from torch.distributed.checkpoint.metadata import ChunkStorageMetadata
 
 __all__: List[str] = []
 
-def _check_shard_metadata_pair_overlap(shard1: ChunkStorageMetadata, shard2: ChunkStorageMetadata):
+
+def _check_shard_metadata_pair_overlap(
+    shard1: ChunkStorageMetadata, shard2: ChunkStorageMetadata
+):
     """Check if two shards overlap."""
     # For each dim of each shard, check if one shard resides on the other
     # end of second shard with respect to that dim. As an example for a 2D
@@ -21,6 +22,7 @@
 
     return True
 
+
 def _shards_get_overlap_region_wrt_saved_tensor(
     saved_shard: ChunkStorageMetadata, current_shard: ChunkStorageMetadata
 ) -> List[Tuple[int, int, int, int]]:
@@ -56,9 +58,7 @@
 
         if saved_shard_offset > current_shard_offset:
             offset_for_saved_tensor = 0
-            offset_for_current_tensor = (
-                saved_shard_offset - current_shard_offset
-            )
+            offset_for_current_tensor = saved_shard_offset - current_shard_offset
         else:
             offset_for_saved_tensor = current_shard_offset - saved_shard_offset
             offset_for_current_tensor = 0
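
`_shards_get_overlap_region_wrt_saved_tensor` computes, per dimension, where the overlap between a saved chunk and the chunk being loaded starts inside each of them and how long it is. A one-dimensional re-derivation of the arithmetic in the hunk above (standalone, for illustration only):

```python
def overlap_region_1d(
    saved_offset: int, saved_len: int, cur_offset: int, cur_len: int
):
    """Return (offset_into_saved, offset_into_current, overlap_length)."""
    if saved_offset > cur_offset:
        offset_for_saved, offset_for_current = 0, saved_offset - cur_offset
    else:
        offset_for_saved, offset_for_current = cur_offset - saved_offset, 0
    length = min(saved_offset + saved_len, cur_offset + cur_len) - max(
        saved_offset, cur_offset
    )
    return offset_for_saved, offset_for_current, length


# The saved chunk covers rows [0, 8) and the loading rank's chunk covers
# rows [4, 12): copy 4 rows, starting at row 4 of the saved chunk, into
# row 0 of the current chunk.
assert overlap_region_1d(0, 8, 4, 8) == (4, 0, 4)
```
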
diff --git a/torch/distributed/checkpoint/state_dict_loader.py b/torch/distributed/checkpoint/state_dict_loader.py
index a9291ae..57941e0 100644
--- a/torch/distributed/checkpoint/state_dict_loader.py
+++ b/torch/distributed/checkpoint/state_dict_loader.py
@@ -1,20 +1,18 @@
-from typing import Any, Dict, Optional
 import warnings
+from typing import Any, Dict, Optional
 
 import torch
 import torch.distributed as dist
 from torch.distributed.checkpoint.stateful import Stateful
 
-from .storage import (
-    StorageReader,
-)
-from .planner import LoadPlanner
 from .default_planner import DefaultLoadPlanner
-
-from .utils import _DistWrapper, _all_gather_keys
+from .planner import LoadPlanner
+from .storage import StorageReader
+from .utils import _all_gather_keys, _DistWrapper
 
 __all__ = ["load_state_dict", "load"]
 
+
 def load_state_dict(
     state_dict: Dict[str, Any],
     storage_reader: StorageReader,
@@ -28,7 +26,10 @@
         "'load_state_dict' is deprecated and will be removed in future versions. Please use 'load' instead."
     )
     # TODO: test returning `load` here instead.
-    return _load_state_dict(state_dict, storage_reader, process_group, coordinator_rank, no_dist, planner)
+    return _load_state_dict(
+        state_dict, storage_reader, process_group, coordinator_rank, no_dist, planner
+    )
+
 
 def load(
     state_dict: Dict[str, Any],
@@ -124,7 +125,9 @@
         elem = state_dict[key]
         statetful_sd[key] = elem.state_dict() if isinstance(elem, Stateful) else elem
 
-    _load_state_dict(statetful_sd, storage_reader, process_group, coordinator_rank, no_dist, planner)
+    _load_state_dict(
+        statetful_sd, storage_reader, process_group, coordinator_rank, no_dist, planner
+    )
     for key in keys:
         if key not in state_dict:
             continue
@@ -133,6 +136,7 @@
             elem.load_state_dict(statetful_sd[key])
         state_dict[key] = elem
 
+
 def _load_state_dict(
     state_dict: Dict[str, Any],
     storage_reader: StorageReader,
@@ -141,7 +145,6 @@
     no_dist: bool = False,
     planner: Optional[LoadPlanner] = None,
 ) -> None:
-
     torch._C._log_api_usage_once("torch.distributed.checkpoint.load_state_dict")
 
     distW = _DistWrapper(process_group, not no_dist, coordinator_rank)
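
Besides deprecating `load_state_dict` in favor of `load`, this file is where `load()` unwraps `Stateful` objects: any value that implements `state_dict()`/`load_state_dict()` is restored in place. A hedged sketch under the same torchrun/`init_process_group` setup as the filesystem example above; `TrainState` and the checkpoint path are illustrative, not part of DCP:

```python
import torch
from torch.distributed.checkpoint.filesystem import FileSystemReader, FileSystemWriter
from torch.distributed.checkpoint.state_dict_loader import load
from torch.distributed.checkpoint.state_dict_saver import save
from torch.distributed.checkpoint.stateful import Stateful


class TrainState(Stateful):
    """Illustrative application object; DCP drives the two methods below."""

    def __init__(self, step: int = 0):
        self.step = torch.tensor(step)  # kept as a tensor so it is restored in place

    def state_dict(self):
        return {"step": self.step}

    def load_state_dict(self, state_dict):
        self.step = state_dict["step"]


CKPT_DIR = "/tmp/dcp_stateful_demo"  # illustrative path; assumes dist is initialized

save(state_dict={"trainer": TrainState(step=7)}, storage_writer=FileSystemWriter(CKPT_DIR))

restored = {"trainer": TrainState()}
load(state_dict=restored, storage_reader=FileSystemReader(CKPT_DIR))
assert restored["trainer"].step.item() == 7
```
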
diff --git a/torch/distributed/checkpoint/state_dict_saver.py b/torch/distributed/checkpoint/state_dict_saver.py
index 2fba073..f672aa3 100644
--- a/torch/distributed/checkpoint/state_dict_saver.py
+++ b/torch/distributed/checkpoint/state_dict_saver.py
@@ -1,18 +1,14 @@
-from typing import Optional
 import warnings
+from typing import Optional
 
 import torch
 import torch.distributed as dist
 from torch.distributed.checkpoint.stateful import Stateful
-from .planner import SavePlanner
+
 from .default_planner import DefaultSavePlanner
-
-
-from .storage import (
-    StorageWriter,
-)
-
 from .metadata import Metadata, STATE_DICT_TYPE
+from .planner import SavePlanner
+from .storage import StorageWriter
 from .utils import _DistWrapper
 
 __all__ = ["save_state_dict", "save"]
@@ -32,7 +28,10 @@
     )
 
     # TODO: test returning `save` here instead.
-    return _save_state_dict(state_dict, storage_writer, process_group, coordinator_rank, no_dist, planner)
+    return _save_state_dict(
+        state_dict, storage_writer, process_group, coordinator_rank, no_dist, planner
+    )
+
 
 def save(
     state_dict: STATE_DICT_TYPE,
@@ -108,7 +107,9 @@
 
     dumpable_state_dict = {}
     for key, elem in state_dict.items():
-        dumpable_state_dict[key] = elem.state_dict() if isinstance(elem, Stateful) else elem
+        dumpable_state_dict[key] = (
+            elem.state_dict() if isinstance(elem, Stateful) else elem
+        )
 
     return _save_state_dict(
         dumpable_state_dict,
@@ -116,9 +117,10 @@
         process_group,
         coordinator_rank,
         no_dist,
-        planner
+        planner,
     )
 
+
 def _save_state_dict(
     state_dict: STATE_DICT_TYPE,
     storage_writer: StorageWriter,
@@ -127,7 +129,6 @@
     no_dist: bool = False,
     planner: Optional[SavePlanner] = None,
 ) -> Metadata:
-
     torch._C._log_api_usage_once("torch.distributed.checkpoint.save_state_dict")
 
     distW = _DistWrapper(process_group, not no_dist, coordinator_rank)
@@ -149,9 +150,7 @@
         nonlocal global_metatadata
 
         assert planner is not None
-        all_local_plans, global_metatadata = planner.create_global_plan(
-            all_local_plans
-        )
+        all_local_plans, global_metatadata = planner.create_global_plan(all_local_plans)
         all_local_plans = storage_writer.prepare_global_plan(all_local_plans)
         return all_local_plans
 
diff --git a/torch/distributed/checkpoint/storage.py b/torch/distributed/checkpoint/storage.py
index a15db14..59282df 100644
--- a/torch/distributed/checkpoint/storage.py
+++ b/torch/distributed/checkpoint/storage.py
@@ -1,20 +1,11 @@
 import abc
 from dataclasses import dataclass
-from typing import List, Any
+from typing import Any, List
 
 from torch.futures import Future
 
-from .metadata import (
-    Metadata,
-    MetadataIndex,
-)
-
-from .planner import (
-    LoadPlan,
-    SavePlan,
-    SavePlanner,
-    LoadPlanner,
-)
+from .metadata import Metadata, MetadataIndex
+from .planner import LoadPlan, LoadPlanner, SavePlan, SavePlanner
 
 __all__ = ["WriteResult", "StorageWriter", "StorageReader"]
 
@@ -115,9 +106,7 @@
         pass
 
     @abc.abstractmethod
-    def finish(
-        self, metadata: Metadata, results: List[List[WriteResult]]
-    ) -> None:
+    def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
         """
         Write the metadata and marks the current checkpoint as successful.
 
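
`storage.py` defines the abstract `StorageWriter`/`StorageReader` contract; `finish()` (reformatted above) is the last hook of a save, invoked on the coordinator with every rank's `WriteResult`s once the data is on disk. A hypothetical subclass of the `FileSystemWriter` from this diff that uses the hook to drop a completion marker (the marker file is illustrative, not a DCP convention):

```python
from typing import List

from torch.distributed.checkpoint.filesystem import FileSystemWriter
from torch.distributed.checkpoint.metadata import Metadata
from torch.distributed.checkpoint.storage import WriteResult


class MarkedFileSystemWriter(FileSystemWriter):
    """FileSystemWriter that additionally drops a COMPLETED marker file."""

    def finish(self, metadata: Metadata, results: List[List[WriteResult]]) -> None:
        # The base class writes the .metadata file that makes the checkpoint loadable.
        super().finish(metadata, results)
        (self.path / "COMPLETED").touch()  # hypothetical marker; DCP never reads it
```

It drops in wherever a `FileSystemWriter` is accepted, e.g. `save(state_dict, storage_writer=MarkedFileSystemWriter(path))`.
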
diff --git a/torch/distributed/checkpoint/utils.py b/torch/distributed/checkpoint/utils.py
index 25f96f1..b8dadf3 100644
--- a/torch/distributed/checkpoint/utils.py
+++ b/torch/distributed/checkpoint/utils.py
@@ -1,37 +1,21 @@
-import os
 import io
 import itertools
-from typing import (
-    List,
-    Callable,
-    Optional,
-    Union,
-    TypeVar,
-    Dict,
-    Any,
-    cast,
-    Sequence
-)
-import torch.distributed as dist
-from .api import (
-    CheckpointException,
-    _wrap_exception,
-    _is_wrapped_exception,
-    WRAPPED_EXCEPTION,
-)
+import os
+from typing import Any, Callable, cast, Dict, List, Optional, Sequence, TypeVar, Union
 
 import torch
-
-from torch.distributed._shard.sharded_tensor import (
-    ShardedTensor,
-)
+import torch.distributed as dist
+from torch.distributed._shard.sharded_tensor import ShardedTensor
 from torch.distributed._shard.sharded_tensor.shard import Shard
 from torch.distributed._tensor import DTensor
 
-from .metadata import (
-    STATE_DICT_TYPE,
-    MetadataIndex,
+from .api import (
+    _is_wrapped_exception,
+    _wrap_exception,
+    CheckpointException,
+    WRAPPED_EXCEPTION,
 )
+from .metadata import MetadataIndex, STATE_DICT_TYPE
 
 __all__ = ["find_tensor_shard", "find_state_dict_object"]
 
@@ -47,6 +31,7 @@
         {i: err for i, err in enumerate(results) if _is_wrapped_exception(err)},
     )
 
+
 def _all_gather_keys(local_dict: Dict[Any, Any]) -> List[Any]:
     """Gathers all keys, and returns them sorted."""
     keys = list(local_dict.keys())
@@ -55,6 +40,7 @@
     dist.all_gather_object(gathered_keys, keys)
     return sorted(set(itertools.chain.from_iterable(gathered_keys)))
 
+
 class _DistWrapper:
     """
     This is a wrapper around PG that provides a series of features around object collectives.
@@ -123,9 +109,7 @@
     def all_gather_object(self, object: T) -> List[T]:
         """Implement functionality similar to c10d::all_gather_object but without distributed enabled."""
         if self.use_dist:
-            gather_objs = cast(
-                List[T], [None] * dist.get_world_size(self.group)
-            )
+            gather_objs = cast(List[T], [None] * dist.get_world_size(self.group))
 
             dist.all_gather_object(
                 object_list=gather_objs, obj=object, group=self.group
@@ -140,9 +124,7 @@
             gather_result = cast(List[T], [None])
             dist.scatter_object_list(
                 scatter_object_output_list=gather_result,
-                scatter_object_input_list=object_list
-                if self.is_coordinator
-                else None,
+                scatter_object_input_list=object_list if self.is_coordinator else None,
                 src=self.coordinator_rank,
                 group=self.group,
             )
@@ -282,9 +264,7 @@
             try:
                 result = map_fun()
             except BaseException as e:
-                result = CheckpointException(
-                    step, {self.rank: _wrap_exception(e)}
-                )
+                result = CheckpointException(step, {self.rank: _wrap_exception(e)})
         final_result = self.broadcast_object(result)
         if isinstance(final_result, CheckpointException):
             raise final_result
@@ -302,22 +282,17 @@
     if index.index is not None:
         if (
             len(shards) > index.index
-            and torch.Size(shards[index.index].metadata.shard_offsets)
-            == index.offset
+            and torch.Size(shards[index.index].metadata.shard_offsets) == index.offset
         ):
             return shards[index.index]
 
     for shard in shards:
         if torch.Size(shard.metadata.shard_offsets) == index.offset:
             return shard
-    raise ValueError(
-        f"Could not find shard at '{index.offset}' for FQN: '{index.fqn}'"
-    )
+    raise ValueError(f"Could not find shard at '{index.offset}' for FQN: '{index.fqn}'")
 
 
-def find_tensor_shard(
-    tensor: torch.Tensor, index: MetadataIndex
-) -> torch.Tensor:
+def find_tensor_shard(tensor: torch.Tensor, index: MetadataIndex) -> torch.Tensor:
     if isinstance(tensor, DTensor):
         return tensor.to_local()
     if isinstance(tensor, ShardedTensor):
@@ -332,9 +307,7 @@
     return tensor
 
 
-def find_state_dict_object(
-    state_dict: STATE_DICT_TYPE, index: MetadataIndex
-) -> Any:
+def find_state_dict_object(state_dict: STATE_DICT_TYPE, index: MetadataIndex) -> Any:
     if index.fqn not in state_dict:
         raise ValueError(f"Could not find FQN: '{index.fqn}'")
     obj = state_dict[index.fqn]