Fix RPC get_worker_info for rank=0 (#52804)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/52804
`rpc.get_worker_info` used to only take string in v1.6. We recently
allow it to accept `int` and `WorkerInfo`, but the previous check
on `worker_name` is no longer correct. This commit adds explicit
`not None` check.
Test Plan: Imported from OSS
Reviewed By: rohan-varma
Differential Revision: D26655089
Pulled By: mrshenli
fbshipit-source-id: fa1545bd6dd2b33bc1e919de46b94e799ab9719c
diff --git a/torch/distributed/rpc/api.py b/torch/distributed/rpc/api.py
index 453ac40..7124c43 100644
--- a/torch/distributed/rpc/api.py
+++ b/torch/distributed/rpc/api.py
@@ -347,7 +347,7 @@
``worker_name`` or :class:`~torch.distributed.rpc.WorkerInfo` of the
current worker if ``worker_name`` is ``None``.
"""
- if worker_name:
+ if worker_name is not None:
return _get_current_rpc_agent().get_worker_info(worker_name)
else:
return _get_current_rpc_agent().get_worker_info()
diff --git a/torch/testing/_internal/distributed/rpc/rpc_test.py b/torch/testing/_internal/distributed/rpc/rpc_test.py
index ee7a54c..f0cdb27 100644
--- a/torch/testing/_internal/distributed/rpc/rpc_test.py
+++ b/torch/testing/_internal/distributed/rpc/rpc_test.py
@@ -885,6 +885,16 @@
)
self.assertEqual(ret, torch.ones(n, n) * 2)
+ @staticmethod
+ def return_callee_id():
+ return rpc.get_worker_info().id
+
+ @dist_init
+ def test_int_callee(self):
+ dst_rank = (self.rank + 1) % self.world_size
+ ret = rpc.rpc_sync(dst_rank, RpcTest.return_callee_id)
+ self.assertEqual(ret, dst_rank)
+
@dist_init
def test_add_with_id(self):
n = self.rank + 1