|  | from typing import Optional, Union | 
|  |  | 
|  | import torch | 
|  |  | 
|  |  | 
class _remote_device(object):
    """
    Represents a device on a remote worker.

    Args:
        remote_device (str or torch.device): Represents a device on a remote worker.
            The string format should be one of the following:

                1. "<workername>/<device>", where the device field can be parsed as torch.device type.
                   E.g., "trainer0/cpu", "trainer0", "ps0/cuda:0".
                   In addition, the device field can be optional and the default value is "cpu".
                2. "rank:<rank>/<device>", where <rank> is the rank of the
                   process and device can be parsed as torch.device type.
                   E.g., "rank:0/cpu", "rank:0", "rank:0/cuda:0"
                3. <workername> and <rank> are optional and formats like "cpu"
                   and "cuda:1", just represent local devices.

    Raises:
        TypeError: if ``remote_device`` is neither a ``str`` nor a ``torch.device``.
        ValueError: if the string does not match one of the formats above
            (more than one ``/``, an empty worker name, or a malformed
            ``rank:<rank>`` / colon-containing worker field).
        Errors raised by ``torch.device`` for an unparsable device field are
        propagated unchanged.

    Two instances compare equal when worker name, rank and device all match;
    ``__hash__`` is defined over the same fields so instances can be stored in
    sets and used as dict keys.
    """

    def __init__(self, remote_device: Union[str, torch.device]):
        PARSE_ERROR = (
            f"Could not parse remote_device: {remote_device}. The valid format is "
            "'<workername>/<device>' or 'rank:<rank>/<device>' or '<device>'"
        )
        self._worker_name: Optional[str] = None
        self._rank: Optional[int] = None
        self._device: Optional[Union[str, int, torch.device]] = None

        if isinstance(remote_device, torch.device):
            self._device = remote_device
        elif isinstance(remote_device, str):
            fields = remote_device.split("/")
            if len(fields) == 2:
                self._worker_name, self._device = fields
            elif len(fields) == 1:
                # A single field is either a bare local device ("cpu", "cuda:1")
                # or a worker/rank spec whose device defaults to "cpu".
                if _remote_device._is_valid_local_device(fields[0]):
                    self._device = fields[0]
                else:
                    self._worker_name = fields[0]
                    self._device = "cpu"
            else:
                raise ValueError(PARSE_ERROR)
        else:
            raise TypeError(f'Invalid type for remote_device: {type(remote_device)}')

        # Reject an empty worker name, e.g. "/cpu".
        if self._worker_name is not None and not self._worker_name:
            raise ValueError(PARSE_ERROR)

        # Normalize to torch.device; an unparsable device string raises from
        # torch.device itself.
        self._device = torch.device(self._device)

        # A worker field of the form "rank:<rank>" selects the rank-based format.
        if self._worker_name is not None:
            fields = self._worker_name.split(":")
            if len(fields) == 2:
                # isdecimal() rather than isdigit(): isdigit() accepts characters
                # such as "²" that int() rejects, which would surface as an
                # unrelated int() error instead of PARSE_ERROR.
                if fields[0] == "rank" and fields[1].isdecimal():
                    self._rank = int(fields[1])
                    self._worker_name = None
                else:
                    raise ValueError(PARSE_ERROR)
            elif len(fields) > 2:
                raise ValueError(PARSE_ERROR)

    @staticmethod
    def _is_valid_local_device(device):
        """Return True if ``torch.device(device)`` accepts ``device``, else False."""
        # Best-effort probe: any failure means "not a local device string".
        try:
            torch.device(device)
            return True
        except Exception:
            return False

    def worker_name(self) -> Optional[str]:
        """
        Returns the name of remote worker representing the remote device.
        Returns ``None`` if no worker name is available.
        """
        return self._worker_name

    def rank(self) -> Optional[int]:
        """
        Returns the rank of remote worker representing the remote device.
        Returns ``None`` if no rank is available.
        """
        return self._rank

    def device(self) -> torch.device:
        """
        Returns the local device on the remote worker.
        """
        return self._device  # type: ignore[return-value]

    def __repr__(self):
        if self._device is not None:
            if self._worker_name is not None:
                return f'{self._worker_name}/{self._device}'
            elif self._rank is not None:
                return f'rank:{self._rank}/{self._device}'
            else:
                return str(self._device)
        else:
            # Not reachable through __init__ (it always sets a device); kept as
            # a defensive fallback, using the same "rank:" prefix as parsing.
            if self._worker_name is not None:
                return f'{self._worker_name}'
            elif self._rank is not None:
                return f'rank:{self._rank}'
            else:
                raise RuntimeError('Invalid state!')

    def __eq__(self, other):
        if not isinstance(other, _remote_device):
            return False
        return (
            self._worker_name == other._worker_name
            and self._device == other._device
            and self._rank == other._rank
        )

    def __hash__(self):
        # Defining __eq__ alone sets __hash__ to None; hash exactly the fields
        # __eq__ compares so equal instances hash equally.
        return hash((self._worker_name, self._device, self._rank))