blob: 323b3608023abd4f267cfcc6f388a7c47d3287e9 [file] [log] [blame]
import torch
def _parse_remote_device(remote_device: str):
r"""
Parses the remote device.
Args:
remote_device (str): Device on the destination worker where we'd like to place this module.
The format should be one of the following:
1. "<workername>/<device>", where the device field can be parsed as torch.device type.
E.g., "trainer0/cpu", "trainer0", "ps0/cuda:0".
In addition, the device field can be optional and the default value is "cpu".
2. "rank:<rank>/<device>", where <rank> is the rank of the
process and device can be parsed as torch.device type.
E.g., "rank:0/cpu", "rank:0", "rank:0/cuda:0"
Returns:
A workername/rank and a device.
"""
PARSE_ERROR = (
f"Could not parse remote_device: {remote_device}. The valid format is "
"'<workername>/<device>' or 'rank:<rank>/<device>'"
)
fields = remote_device.split("/")
if len(fields) == 2:
[on, device] = fields
elif len(fields) == 1:
on = fields[0]
device = "cpu"
else:
raise ValueError(PARSE_ERROR)
# Since the workername in the input remote device won't be validated until the created remote module is executed,
# only do some very basic sanity check on workername at the module creation time.
# As currently there is no regex to describe the format of workername, just check whether the workername is empty.
if not on:
raise ValueError(PARSE_ERROR)
# Validate the device.
torch.device(device)
# Check for rank based format
fields = on.split(':')
if len(fields) == 2:
# rank:<rank>/device format, extract rank
if fields[0] == 'rank' and fields[1].isdigit():
on = int(fields[1]) # type: ignore[assignment]
else:
raise ValueError(PARSE_ERROR)
elif len(fields) > 2:
raise ValueError(PARSE_ERROR)
return on, device