Throw an actionable error message on user call rref<ScriptModule>.to_here() in torchscript (#35369)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/35369
For issue, https://github.com/pytorch/pytorch/issues/35367
Test Plan:
```
buck build mode/dev-nosan //caffe2/test/distributed/rpc/jit:rpc_fork \
&& buck-out/gen/caffe2/test/distributed/rpc/jit/rpc_fork\#binary.par \
-r test_remote_script_module
buck build mode/dev-nosan //caffe2/test/distributed/rpc/jit:rpc_fork \
&& buck-out/gen/caffe2/test/distributed/rpc/jit/rpc_fork\#binary.par \
-r test_torchscript_functions_not_supported
```
Differential Revision: D7870906
fbshipit-source-id: 2e78f2e620a5cc7c8f26ab35400ba33bb303788d
diff --git a/torch/csrc/distributed/rpc/rref_impl.cpp b/torch/csrc/distributed/rpc/rref_impl.cpp
index f6fe2bd..42341be 100644
--- a/torch/csrc/distributed/rpc/rref_impl.cpp
+++ b/torch/csrc/distributed/rpc/rref_impl.cpp
@@ -115,6 +115,19 @@
" and ForkId=",
forkId(),
" has been deleted. Cannot call to_here() on it after deletion.");
+ TORCH_CHECK(
+ !type_->is_module(),
+ "User RRef with RRefId=",
+ rrefId(),
+ " and ForkId=",
+ forkId(),
+ " is an RRef to a ScriptModule. "
+ "It can't be sent through RPC "
+ "from owner, ",
+ ownerName(),
+ ", to user, ",
+ RpcAgent::getCurrentRpcAgent()->getWorkerInfo().name_,
+ ".");
auto agent = RpcAgent::getCurrentRpcAgent();
diff --git a/torch/testing/_internal/distributed/rpc/jit/rpc_test.py b/torch/testing/_internal/distributed/rpc/jit/rpc_test.py
index e5c75a7..7e95168 100644
--- a/torch/testing/_internal/distributed/rpc/jit/rpc_test.py
+++ b/torch/testing/_internal/distributed/rpc/jit/rpc_test.py
@@ -691,6 +691,16 @@
)
self.assertEqual(ret, local_ret)
+ # pass rref arg to self/user
+ with self.assertRaisesRegex(
+ RuntimeError, "is an RRef to a ScriptModule. It can't be sent through RPC from owner,"
+ ):
+ ret = rpc.rpc_sync(
+ worker_name(self.rank),
+ run_ref_script_module,
+ args=(remote_ref, torch.ones(self.rank)),
+ )
+
@dist_init
def test_rref_is_owner(self):
n = self.rank + 1