Enable custom device support in fsdp checkpoint (#107289)

Fixes https://github.com/pytorch/pytorch/issues/104390
Enables custom device (privateuse1 backend) support in checkpointing by resolving the backend's device module dynamically instead of hard-coding `torch.cuda`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/107289
Approved by: https://github.com/wz337
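
For context, a minimal sketch of the idea behind the change (illustration only, not part of the patch; the `foo` backend name is a placeholder for whatever a privateuse1 extension registers):

```python
import torch

def resolve_device_module(device_type: str):
    # Mirrors the helper this PR adds as torch._utils._get_device_module:
    # look up the module registered under torch.<device_type> at runtime
    # (torch.cuda, or a privateuse1 backend exposed as e.g. torch.foo).
    module = getattr(torch, device_type, None)
    if module is None:
        raise RuntimeError(
            f"Device '{device_type}' has no module registered as 'torch.{device_type}'."
        )
    return module

# Checkpointing code can then be written once against the module surface
# (current_stream, device_count, stream, ...) instead of torch.cuda.*:
mod = resolve_device_module("cuda")
if mod.is_available():
    stream = mod.current_stream()   # previously: torch.cuda.current_stream()
```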
diff --git a/torch/_utils.py b/torch/_utils.py
index bcb2c3c..36142a6 100644
--- a/torch/_utils.py
+++ b/torch/_utils.py
@@ -1,4 +1,5 @@
 import copyreg
+import functools
 import sys
 import traceback
 import warnings
@@ -839,3 +840,13 @@
 # Whether we are compiling with torch.compile or not
 def is_compiling():
     return False
+
+
+@functools.lru_cache(2)
+def _get_device_module(device_type: str):
+    device_module = getattr(torch, device_type, None)
+    if device_module is None:
+        raise RuntimeError(
+            f"Device '{device_type}' does not have a corresponding module registered as 'torch.{device_type}'."
+        )
+    return device_module
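
A hedged usage sketch of the new helper, assuming a stock build where `torch.cuda` is importable (a GPU need not be present):

```python
import torch
from torch._utils import _get_device_module

# Successful lookups are memoized (functools.lru_cache(2)), so repeated calls
# for the same backend return the same module object.
assert _get_device_module("cuda") is torch.cuda
assert _get_device_module("cuda") is _get_device_module("cuda")

# Unknown backend names fail loudly instead of silently falling back to CUDA.
try:
    _get_device_module("not_a_backend")
except RuntimeError as err:
    print(err)
```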
diff --git a/torch/distributed/checkpoint/_fsspec_filesystem.py b/torch/distributed/checkpoint/_fsspec_filesystem.py
index b8d1c24..0d37924 100644
--- a/torch/distributed/checkpoint/_fsspec_filesystem.py
+++ b/torch/distributed/checkpoint/_fsspec_filesystem.py
@@ -18,6 +18,7 @@
 import torch
 from fsspec.core import url_to_fs
 from torch import Tensor
+from torch._utils import _get_device_module
 
 from torch.distributed._shard._utils import narrow_tensor_by_index
 from torch.distributed.checkpoint.metadata import Metadata, MetadataIndex
@@ -114,7 +115,7 @@
     def __init__(
         self,
         resolve_fun: Callable,
-        stream: Union[None, io.RawIOBase, torch._C._CudaStreamBase] = None,
+        stream: Union[None, io.RawIOBase, torch.Stream] = None,
         inflight_threshhold: int = 1_000_000,
     ):
         self.resolve_fun = resolve_fun
@@ -124,9 +125,11 @@
         self.current_items: collections.deque = collections.deque()
         self.idx = 0
         self.started = False
-        self.stream = stream or torch.cuda.current_stream()
-        if self.stream != torch.cuda.current_stream():
-            self.stream.wait_stream(torch.cuda.current_stream())
+        self.device_type = stream.device_type if stream else torch.device("cuda").type
+        self.device_module = _get_device_module(self.device_type)
+        self.stream = stream or self.device_module.current_stream()
+        if self.stream != self.device_module.current_stream():
+            self.stream.wait_stream(self.device_module.current_stream())
 
     @property
     def _done(self):
@@ -143,7 +146,7 @@
         return drained
 
     def _refill(self):
-        with torch.cuda.stream(self.stream):
+        with self.device_module.stream(self.stream):
             while (
                 not self._done
                 and self.in_flight_data < self.inflight_threshhold
@@ -151,7 +154,7 @@
                 _, obj = self.items[self.idx]
                 self.idx += 1
                 tensor = self.resolve_fun(obj).detach()
-                if tensor.is_cuda:
+                if tensor.device.type == self.device_type:
                     tensor = tensor.to(device="cpu", non_blocking=True)
                 elif tensor.device == torch.device("cpu"):
                     if tensor.storage().size() != tensor.numel():
@@ -232,7 +235,7 @@
 
 
 def _write_item(
-    stream: Optional[Union[io.RawIOBase, torch._C._CudaStreamBase]],
+    stream: Optional[Union[io.RawIOBase, torch.Stream]],
     data: Union[io.BytesIO, torch.Tensor],
     write_item: WriteItem,
     storage_key: str,
@@ -294,7 +297,7 @@
                     )
 
                 for tensor, write_item in loader.values():
-                    assert not tensor.is_cuda
+                    assert tensor.is_cpu
                     write_results.append(
                         _write_item(stream, tensor, write_item, storage_key)
                     )
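
Illustration only: the device-to-host overlap pattern `_OverlappingCpuLoader` relies on, rewritten against the resolved device module. It assumes the backend module exposes `Stream`, `stream`, and `current_stream`, which is exactly what the change above requires; batching and the inflight-size accounting are omitted.

```python
import torch
from torch._utils import _get_device_module

def copy_to_cpu_overlapped(tensors, device_type="cuda"):
    """Stage device tensors to CPU copies on a side stream (sketch only)."""
    mod = _get_device_module(device_type)
    side_stream = mod.Stream()
    side_stream.wait_stream(mod.current_stream())  # order after pending work
    cpu_copies = []
    with mod.stream(side_stream):
        for t in tensors:
            if t.device.type == device_type:        # was: t.is_cuda
                cpu_copies.append(t.to("cpu", non_blocking=True))
            else:
                cpu_copies.append(t)
    side_stream.synchronize()                       # copies done; safe to write out
    return cpu_copies
```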
diff --git a/torch/distributed/checkpoint/_sharded_tensor_utils.py b/torch/distributed/checkpoint/_sharded_tensor_utils.py
index 8d39be2..07bbdc9 100644
--- a/torch/distributed/checkpoint/_sharded_tensor_utils.py
+++ b/torch/distributed/checkpoint/_sharded_tensor_utils.py
@@ -26,7 +26,7 @@
     STATE_DICT_ITEM,
 )
 
-from .utils import _element_wise_add
+from .utils import _element_wise_add, _normalize_device_info
 
 
 # TODO: We need to refactor this code.
@@ -83,6 +83,7 @@
 
         st_meta: ShardedTensorMetadata = copy.deepcopy(value.metadata())
         other_rank = 0 if dist.get_rank() > 0 else 1
+        device_info = _normalize_device_info(inner_shard.tensor.device.type, 0)
 
         # Remove the outer ST shard the inner ST covers
         for i, shard_md in enumerate(st_meta.shards_metadata):
@@ -92,7 +93,7 @@
 
         # Attribute other rank for the other shards
         for shard_md in st_meta.shards_metadata:
-            shard_md.placement = _remote_device(f"rank:{other_rank}/cuda:0")
+            shard_md.placement = _remote_device(f"rank:{other_rank}/{device_info}")
 
         # Add other inner shards from the inner tensor
         for inner_md in inner_st.metadata().shards_metadata:
@@ -104,7 +105,7 @@
                             inner_md.shard_offsets,
                         ),
                         shard_sizes=inner_md.shard_sizes,
-                        placement=f"rank:{other_rank}/cuda:0",
+                        placement=f"rank:{other_rank}/{device_info}",
                     )
                 )
 
diff --git a/torch/distributed/checkpoint/filesystem.py b/torch/distributed/checkpoint/filesystem.py
index c6ef58c..d23bf27 100644
--- a/torch/distributed/checkpoint/filesystem.py
+++ b/torch/distributed/checkpoint/filesystem.py
@@ -39,6 +39,7 @@
 from .utils import _create_file_view
 
 from torch.distributed._shard._utils import narrow_tensor_by_index
+from torch._utils import _get_device_module
 
 __all__ = [
     "FileSystemWriter",
@@ -126,9 +127,11 @@
         self.current_items: collections.deque = collections.deque()
         self.idx = 0
         self.started = False
-        self.stream = stream or torch.cuda.current_stream()
-        if self.stream != torch.cuda.current_stream():
-            self.stream.wait_stream(torch.cuda.current_stream())
+        self.device_type = stream.device_type if stream else torch.device("cuda").type
+        self.device_module = _get_device_module(self.device_type)
+        self.stream = stream or self.device_module.current_stream()
+        if self.stream != self.device_module.current_stream():
+            self.stream.wait_stream(self.device_module.current_stream())
 
     @property
     def _done(self):
@@ -145,7 +148,7 @@
         return drained
 
     def _refill(self):
-        with torch.cuda.stream(self.stream):
+        with self.device_module.stream(self.stream):
             while (
                 not self._done
                 and self.in_flight_data < self.inflight_threshhold
@@ -153,7 +156,7 @@
                 _, obj = self.items[self.idx]
                 self.idx += 1
                 tensor = self.resolve_fun(obj).detach()
-                if tensor.is_cuda:
+                if tensor.device.type == self.device_type:
                     tensor = tensor.to(device="cpu", non_blocking=True)
                 elif tensor.device == torch.device("cpu"):
                     if tensor.storage().size() != tensor.numel():
@@ -292,7 +295,7 @@
                     )
 
                 for tensor, write_item in loader.values():
-                    assert not tensor.is_cuda
+                    assert tensor.is_cpu
                     write_results.append(
                         _write_item(stream, tensor, write_item, storage_key)
                     )
diff --git a/torch/distributed/checkpoint/optimizer.py b/torch/distributed/checkpoint/optimizer.py
index 67e4504..0d359aa 100644
--- a/torch/distributed/checkpoint/optimizer.py
+++ b/torch/distributed/checkpoint/optimizer.py
@@ -38,8 +38,11 @@
 from torch.distributed.checkpoint.utils import (
     _element_wise_add,
     _element_wise_sub,
+    _normalize_device_info
 )
 
+from torch._utils import _get_device_module
+
 STATE_DICT_2D_LAYOUT = Dict[str, Tuple[Optional[Sequence[int]], Sequence[int]]]
 
 
@@ -49,23 +52,27 @@
 ]
 
 
-def _gen_rank_device(global_rank: int) -> str:
-    if torch.cuda.is_available():
-        return f"cuda:{global_rank % torch.cuda.device_count()}"
+def _gen_rank_device(global_rank: int, device_type: str = "cuda") -> str:
+    if device_type == "cpu":
+        return "cpu"
+    device_module = _get_device_module(device_type)
+    if device_module.is_available():
+        return _normalize_device_info(device_type, global_rank % device_module.device_count())
     return "cpu"
 
 
 def _create_colwise_spec(
     pg: Optional[dist.ProcessGroup] = None,
 ) -> ChunkShardingSpec:
+    pg_device_type = dist.distributed_c10d._get_pg_default_device(pg).type
     if pg is None:
         placements = [
-            f"rank:{idx}/{_gen_rank_device(idx)}"
+            f"rank:{idx}/{_gen_rank_device(idx, pg_device_type)}"
             for idx in range(dist.get_world_size())
         ]
     else:
         placements = [
-            f"rank:{idx}/{_gen_rank_device(dist.get_global_rank(pg, idx))}"
+            f"rank:{idx}/{_gen_rank_device(dist.get_global_rank(pg, idx), pg_device_type)}"
             for idx in range(pg.size())
         ]
     return ChunkShardingSpec(
@@ -92,14 +99,14 @@
     return False
 
 
-def _alloc_tensor(props: TensorProperties, size: Sequence[int]) -> torch.Tensor:
+def _alloc_tensor(props: TensorProperties, size: Sequence[int], device_type: str = "cuda") -> torch.Tensor:
     return torch.empty(
         size=size,
         dtype=props.dtype,
         layout=props.layout,
         requires_grad=props.requires_grad,
         pin_memory=props.pin_memory,
-        device=cast(torch.device, torch.cuda.current_device()),
+        device=cast(torch.device, _get_device_module(device_type).current_device()),
     )
 
 
@@ -255,15 +262,15 @@
     metadata = storage_reader.read_metadata()
 
     layout_specs, dp_pg = _get_state_dict_2d_layout(model_state_dict)
+    dp_pg_device_type = dist.distributed_c10d._get_pg_default_device(dp_pg).type
+    device_module = _get_device_module(dp_pg_device_type)
 
     if dp_pg is None:
-        sharding_spec = ChunkShardingSpec(
-            dim=0,
-            placements=[
-                f"rank:{i}/cuda:{i % torch.cuda.device_count()}"
-                for i in range(dist.get_world_size())
-            ],
-        )
+        placements = []
+        for i in range(dist.get_world_size()):
+            device_info = _normalize_device_info(dp_pg_device_type, i % device_module.device_count())
+            placements.append(f"rank:{i}/{device_info}")
+        sharding_spec = ChunkShardingSpec(dim=0, placements=placements)  # type: ignore[arg-type]
     else:
         sharding_spec = _create_colwise_spec(dp_pg)
 
@@ -282,10 +289,10 @@
 
         # value: TensorStorageMetadata
         if value.size.numel() == 1:
-            state_dict[key] = _alloc_tensor(value.properties, value.size)
+            state_dict[key] = _alloc_tensor(value.properties, value.size, dp_pg_device_type)
         elif dp_pg is None:
             state_dict[key] = _shard_tensor(
-                _alloc_tensor(value.properties, value.size), sharding_spec
+                _alloc_tensor(value.properties, value.size, dp_pg_device_type), sharding_spec
             )
         else:
             spec_key = key_path[2]
@@ -305,7 +312,7 @@
                 local_shards.append(
                     Shard(
                         tensor=_alloc_tensor(
-                            value.properties, shard_md.shard_sizes
+                            value.properties, shard_md.shard_sizes, dp_pg_device_type
                         ),
                         metadata=shard_md,
                     )
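
Illustration only: a rough sketch of what the updated placement and allocation logic computes, taking the device type from the process group's default device instead of assuming CUDA. The helper names below (`rank_device`, `alloc_like`) are made up for the example.

```python
import torch
from torch._utils import _get_device_module

def rank_device(global_rank: int, device_type: str = "cuda") -> str:
    # Round-robin ranks over the visible devices of the resolved backend;
    # fall back to CPU when the backend reports no devices.
    if device_type == "cpu":
        return "cpu"
    mod = _get_device_module(device_type)
    if mod.is_available():
        return f"{device_type}:{global_rank % mod.device_count()}"
    return "cpu"

def alloc_like(size, dtype, device_type: str = "cuda") -> torch.Tensor:
    # Allocate on the backend's current device rather than torch.cuda.current_device().
    mod = _get_device_module(device_type)
    return torch.empty(size, dtype=dtype,
                       device=torch.device(device_type, mod.current_device()))

print(rank_device(5, "cuda"))   # e.g. "cuda:1" on a 4-GPU host, "cpu" without GPUs
```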
diff --git a/torch/distributed/checkpoint/utils.py b/torch/distributed/checkpoint/utils.py
index 546b0bc..d110503 100644
--- a/torch/distributed/checkpoint/utils.py
+++ b/torch/distributed/checkpoint/utils.py
@@ -355,6 +355,7 @@
 def _element_wise_sub(a: Sequence[int], b: Sequence[int]) -> List[int]:
     return [i_a - i_b for i_a, i_b in zip(a, b)]
 
+
 class _ReaderView(io.IOBase):
     def __init__(self, base_stream: io.IOBase, offset: int, len: int):
         super().__init__()
@@ -386,6 +387,16 @@
     def read(self, size=-1):
         return self.base_stream.read(size)
 
+
 def _create_file_view(file: io.IOBase, offset: int, length: int) -> io.IOBase:
     # FIXME (kumpera) torch.load fails if we wrap with io.BufferedReader
     return _ReaderView(file, offset, length)
+
+
+def _normalize_device_info(device_type: str, device_id: int) -> str:
+    """
+    Device info normalization.
+    """
+    if device_type == "cpu":
+        return "cpu"
+    return f"{device_type}:{device_id}"
diff --git a/torch/distributed/utils.py b/torch/distributed/utils.py
index eb7dc5a..c7ab8b9 100644
--- a/torch/distributed/utils.py
+++ b/torch/distributed/utils.py
@@ -107,7 +107,7 @@
                 with device_mod.stream(stream):
                     output = obj.to(target_device)
                 # synchronize with the copy stream
-                with torch.cuda.device(target_device.index):
+                with device_mod.device(target_device.index):
                     current_stream = device_mod.current_stream()
                     # Sync the current stream with the copy stream
                     current_stream.wait_stream(stream)
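
A hedged sketch of the synchronization pattern this last hunk generalizes: the compute stream on the target device waits on the copy stream through the resolved device module, assuming that module exposes `device(index)`, `Stream`, and `current_stream` like `torch.cuda` does.

```python
import torch
from torch._utils import _get_device_module

def copy_with_stream_sync(obj: torch.Tensor, target_device: torch.device):
    device_mod = _get_device_module(target_device.type)
    copy_stream = device_mod.Stream(device=target_device)
    with device_mod.stream(copy_stream):
        out = obj.to(target_device, non_blocking=True)
    # Make the compute stream on the *target* device wait for the copy,
    # without assuming the target is CUDA.
    with device_mod.device(target_device.index):
        device_mod.current_stream().wait_stream(copy_stream)
    return out
```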