Split deserialize from _run_function in RPC internal.py (#34494)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/34494

Differential Revision: D20347463

Test Plan: Imported from OSS

Pulled By: mrshenli

fbshipit-source-id: e6fd886622f26c46bb83ac118e67abb2f5b296b9
diff --git a/torch/csrc/distributed/rpc/python_rpc_handler.cpp b/torch/csrc/distributed/rpc/python_rpc_handler.cpp
index 8a2879a..2800ab6 100644
--- a/torch/csrc/distributed/rpc/python_rpc_handler.cpp
+++ b/torch/csrc/distributed/rpc/python_rpc_handler.cpp
@@ -95,18 +95,17 @@
     const SerializedPyObj& serializedPyObj,
     std::vector<torch::Tensor>& responseTensorTable) {
   PROFILE_GIL_SCOPED_ACQUIRE;
-  auto pargs = py::bytes(serializedPyObj.payload_);
-  py::tuple pres =
-      pySerialize_(pyRunFunction_(pargs, serializedPyObj.tensors_));
+  auto pythonUdf = deserialize(serializedPyObj);
+  py::tuple pres = pySerialize_(pyRunFunction_(std::move(pythonUdf)));
   responseTensorTable = pres[1].cast<std::vector<torch::Tensor>>();
   return pres[0].cast<std::string>();
 }
 
 py::object PythonRpcHandler::runPythonUDF(
-    const SerializedPyObj& serializedObj) {
+    const SerializedPyObj& serializedPyObj) {
   PROFILE_GIL_SCOPED_ACQUIRE;
-  return pyRunFunction_(
-      py::bytes(serializedObj.payload_), serializedObj.tensors_);
+  auto pythonUdf = deserialize(serializedPyObj);
+  return pyRunFunction_(std::move(pythonUdf));
 }
 
 SerializedPyObj PythonRpcHandler::serialize(const py::object& obj) {
@@ -118,6 +117,9 @@
 
 py::object PythonRpcHandler::deserialize(const SerializedPyObj& serializedObj) {
   PROFILE_GIL_SCOPED_ACQUIRE;
+  // NB: pyDeserialize_ can return an AttributeError if the deserialize() Python
+  // function fails. Functions consuming the result need to handle such errors
+  // properly.
   return pyDeserialize_(
       py::bytes(serializedObj.payload_), serializedObj.tensors_);
 }
diff --git a/torch/distributed/rpc/internal.py b/torch/distributed/rpc/internal.py
index fb4b6d7..fdd661f 100644
--- a/torch/distributed/rpc/internal.py
+++ b/torch/distributed/rpc/internal.py
@@ -98,7 +98,7 @@
             except_str = str(e) + """ Default RPC pickler does not serialize
             function code. Ensure that UDFs are defined on both caller and
             callee modules."""
-            raise AttributeError(except_str)
+            ret = AttributeError(except_str)
 
         # restore _thread_local_tensor_tables.recv_tables if return
         # from nested call, otherwise clean up the table
@@ -122,7 +122,7 @@
     return _internal_rpc_pickler.deserialize(binary_data, tensor_table)
 
 
-def _run_function(binary_data, tensor_table):
+def _run_function(python_udf):
     r"""
     This function is exclusively called from C++.
     See ``torch/csrc/distributed/rpc/python_rpc_handler.cpp``.
@@ -131,7 +131,8 @@
     Wraps any exception in ``RemoteException`` if the function raises.
     """
     try:
-        python_udf = _internal_rpc_pickler.deserialize(binary_data, tensor_table)
+        if isinstance(python_udf, AttributeError):
+            raise python_udf
         result = python_udf.func(*python_udf.args, **python_udf.kwargs)
     except Exception as e:
         # except str = exception info + traceback string