Split deserialize from _run_function in RPC internal.py (#34494)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/34494
Differential Revision: D20347463
Test Plan: Imported from OSS
Pulled By: mrshenli
fbshipit-source-id: e6fd886622f26c46bb83ac118e67abb2f5b296b9
diff --git a/torch/csrc/distributed/rpc/python_rpc_handler.cpp b/torch/csrc/distributed/rpc/python_rpc_handler.cpp
index 8a2879a..2800ab6 100644
--- a/torch/csrc/distributed/rpc/python_rpc_handler.cpp
+++ b/torch/csrc/distributed/rpc/python_rpc_handler.cpp
@@ -95,18 +95,17 @@
const SerializedPyObj& serializedPyObj,
std::vector<torch::Tensor>& responseTensorTable) {
PROFILE_GIL_SCOPED_ACQUIRE;
- auto pargs = py::bytes(serializedPyObj.payload_);
- py::tuple pres =
- pySerialize_(pyRunFunction_(pargs, serializedPyObj.tensors_));
+ auto pythonUdf = deserialize(serializedPyObj);
+ py::tuple pres = pySerialize_(pyRunFunction_(std::move(pythonUdf)));
responseTensorTable = pres[1].cast<std::vector<torch::Tensor>>();
return pres[0].cast<std::string>();
}
py::object PythonRpcHandler::runPythonUDF(
- const SerializedPyObj& serializedObj) {
+ const SerializedPyObj& serializedPyObj) {
PROFILE_GIL_SCOPED_ACQUIRE;
- return pyRunFunction_(
- py::bytes(serializedObj.payload_), serializedObj.tensors_);
+ auto pythonUdf = deserialize(serializedPyObj);
+ return pyRunFunction_(std::move(pythonUdf));
}
SerializedPyObj PythonRpcHandler::serialize(const py::object& obj) {
@@ -118,6 +117,9 @@
py::object PythonRpcHandler::deserialize(const SerializedPyObj& serializedObj) {
PROFILE_GIL_SCOPED_ACQUIRE;
+ // NB: pyDeserialize_ can return an AttributeError if the deserialize() Python
+ // function fails. Functions consuming the result needs to handle such error
+ // properly.
return pyDeserialize_(
py::bytes(serializedObj.payload_), serializedObj.tensors_);
}
diff --git a/torch/distributed/rpc/internal.py b/torch/distributed/rpc/internal.py
index fb4b6d7..fdd661f 100644
--- a/torch/distributed/rpc/internal.py
+++ b/torch/distributed/rpc/internal.py
@@ -98,7 +98,7 @@
except_str = str(e) + """ Default RPC pickler does not serialize
function code. Ensure that UDFs are defined on both caller and
callee modules."""
- raise AttributeError(except_str)
+ ret = AttributeError(except_str)
# restore _thread_local_tensor_tables.recv_tables if return
# from nested call, otherwise clean up the table
@@ -122,7 +122,7 @@
return _internal_rpc_pickler.deserialize(binary_data, tensor_table)
-def _run_function(binary_data, tensor_table):
+def _run_function(python_udf):
r"""
This function is exclusively called from C++.
See ``torch/csrc/distributed/rpc/python_rpc_handler.cpp``.
@@ -131,7 +131,8 @@
Wraps any exception in ``RemoteException`` if the function raises.
"""
try:
- python_udf = _internal_rpc_pickler.deserialize(binary_data, tensor_table)
+ if isinstance(python_udf, AttributeError):
+ raise python_udf
result = python_udf.func(*python_udf.args, **python_udf.kwargs)
except Exception as e:
# except str = exception info + traceback string