Fix RPC get_worker_info for rank=0 (#52804)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/52804

`rpc.get_worker_info` used to accept only a string in v1.6. We recently
allowed it to also accept `int` and `WorkerInfo`, but the previous
truthiness check on `worker_name` is no longer correct: an `int` rank of
0 is falsy and was wrongly treated as "no name given". This commit adds
an explicit `is not None` check.

Test Plan: Imported from OSS

Reviewed By: rohan-varma

Differential Revision: D26655089

Pulled By: mrshenli

fbshipit-source-id: fa1545bd6dd2b33bc1e919de46b94e799ab9719c
diff --git a/torch/distributed/rpc/api.py b/torch/distributed/rpc/api.py
index 453ac40..7124c43 100644
--- a/torch/distributed/rpc/api.py
+++ b/torch/distributed/rpc/api.py
@@ -347,7 +347,7 @@
         ``worker_name`` or :class:`~torch.distributed.rpc.WorkerInfo` of the
         current worker if ``worker_name`` is ``None``.
     """
-    if worker_name:
+    if worker_name is not None:
         return _get_current_rpc_agent().get_worker_info(worker_name)
     else:
         return _get_current_rpc_agent().get_worker_info()
diff --git a/torch/testing/_internal/distributed/rpc/rpc_test.py b/torch/testing/_internal/distributed/rpc/rpc_test.py
index ee7a54c..f0cdb27 100644
--- a/torch/testing/_internal/distributed/rpc/rpc_test.py
+++ b/torch/testing/_internal/distributed/rpc/rpc_test.py
@@ -885,6 +885,16 @@
         )
         self.assertEqual(ret, torch.ones(n, n) * 2)
 
+    @staticmethod
+    def return_callee_id():
+        return rpc.get_worker_info().id
+
+    @dist_init
+    def test_int_callee(self):
+        dst_rank = (self.rank + 1) % self.world_size
+        ret = rpc.rpc_sync(dst_rank, RpcTest.return_callee_id)
+        self.assertEqual(ret, dst_rank)
+
     @dist_init
     def test_add_with_id(self):
         n = self.rank + 1