blob: e0e9e0a1aa83807ff7bdd89407e1580f54b020d8 [file] [log] [blame]
import collections
import copyreg
import io
import pickle
import threading
import traceback
import torch
# Thread local tensor tables to store tensors while pickling torch.Tensor
# objects. _InternalRPCPickler sets a ``send_tables`` attribute on this object
# while serializing (tensors collected by the reducer) and a ``recv_tables``
# attribute while deserializing (tensors looked up by index).
_thread_local_tensor_tables = threading.local()
class _InternalRPCPickler:
    r"""
    This class provides serialize() and deserialize() interfaces to serialize
    data to be "binary string + tensor table" format.

    So for RPC python UDF function and args, non tensor data will be serialized
    into regular binary string, tensor data will be put into thread local tensor
    tables, this serialization format is consistent with builtin operator and args
    using JIT pickler. This format will make tensor handling in C++ much easier,
    e.g. attach tensor to distributed autograd graph in C++
    """

    def __init__(self):
        # python2 does not have dispatch_table, add "if torch._six.PY3" condition,
        # as _InternalRPCPickler still got build in python2 even
        # we skipped python 2 tests for rpc_test
        if torch._six.PY3:
            self._dispatch_table = copyreg.dispatch_table.copy()
            # Route every torch.Tensor through _tensor_reducer so tensors end
            # up in the tensor table instead of the pickled byte stream.
            self._dispatch_table[torch.Tensor] = self._tensor_reducer

    @classmethod
    def _tensor_receiver(cls, tensor_index):
        r"""
        Unpickling hook: return the tensor stored at ``tensor_index`` in the
        current thread's ``recv_tables``.
        """
        return _thread_local_tensor_tables.recv_tables[tensor_index]

    def _tensor_reducer(self, obj):
        r"""
        Pickling hook: append ``obj`` to the current thread's ``send_tables``
        and pickle only its index, to be resolved by ``_tensor_receiver``.
        """
        send_tables = _thread_local_tensor_tables.send_tables
        send_tables.append(obj)
        tensor_index = len(send_tables) - 1
        return (_InternalRPCPickler._tensor_receiver, (tensor_index, ))

    def serialize(self, obj):
        r"""
        Serialize non tensor data into binary string, tensor data into
        tensor table.

        Returns a tuple ``(binary_data, tensors)`` where ``binary_data`` is the
        pickled byte string and ``tensors`` is the list of tensors collected
        while pickling ``obj``. Any exception raised while pickling propagates
        to the caller after the thread local table has been restored.
        """
        f = io.BytesIO()
        p = pickle.Pickler(f)
        p.dispatch_table = self._dispatch_table

        # save _thread_local_tensor_tables.send_tables if it is in nested call
        old_send_tables = getattr(
            _thread_local_tensor_tables, "send_tables", None
        )
        _thread_local_tensor_tables.send_tables = []
        try:
            p.dump(obj)
            tensors = _thread_local_tensor_tables.send_tables
        finally:
            # restore _thread_local_tensor_tables.send_tables if return
            # from nested call, otherwise clean up the table. Done in a
            # ``finally`` so a failing dump() does not leave a stale table
            # (and the tensors it references) attached to this thread.
            if old_send_tables is not None:
                _thread_local_tensor_tables.send_tables = old_send_tables
            else:
                del _thread_local_tensor_tables.send_tables

        return (f.getvalue(), tensors)

    def deserialize(self, binary_data, tensor_table):
        r"""
        Deserialize binary string + tensor table to original obj.

        ``binary_data`` is the pickled byte string and ``tensor_table`` is the
        sequence of tensors referenced by index from within it. Any exception
        raised while unpickling propagates to the caller after the thread
        local table has been restored.
        """
        # save _thread_local_tensor_tables.recv_tables if it is in nested call
        old_recv_tables = getattr(
            _thread_local_tensor_tables, "recv_tables", None
        )
        _thread_local_tensor_tables.recv_tables = tensor_table
        try:
            # NOTE(review): pickle.loads can execute arbitrary code embedded in
            # binary_data; this presumably relies on RPC peers being trusted --
            # confirm against the deployment's threat model.
            ret = pickle.loads(binary_data)
        finally:
            # restore _thread_local_tensor_tables.recv_tables if return
            # from nested call, otherwise clean up the table. Done in a
            # ``finally`` so a failing loads() does not leave the caller's
            # tensor_table attached to this thread.
            if old_recv_tables is not None:
                _thread_local_tensor_tables.recv_tables = old_recv_tables
            else:
                del _thread_local_tensor_tables.recv_tables

        return ret
# Create _internal_rpc_pickler only once to initialize _dispatch_table only
# once; the module-level helper functions in this file share this instance.
_internal_rpc_pickler = _InternalRPCPickler()
def serialize(obj):
    r"""
    Serialize ``obj`` with the shared module-level pickler and return
    whatever its ``serialize()`` method returns.
    """
    serialized = _internal_rpc_pickler.serialize(obj)
    return serialized
def _run_function(binary_data, tensor_table):
    r"""
    This function is exclusively called from C++.
    See ``torch/csrc/distributed/rpc/python_rpc_handler.cpp``.

    Deserializes a Python UDF from ``binary_data`` and ``tensor_table``, calls
    it with its stored args and kwargs, and returns the call's return value.
    If the call raises an ``Exception``, a ``RemoteException`` carrying the
    exception repr and the formatted traceback is returned instead.
    """
    udf = _internal_rpc_pickler.deserialize(binary_data, tensor_table)
    try:
        return udf.func(*udf.args, **udf.kwargs)
    except Exception as exc:
        # message = exception info + traceback string
        message = "{}\n{}".format(repr(exc), traceback.format_exc())
        return RemoteException(message)
def _load_return_value(binary_data, tensor_table):
    r"""
    This function is exclusively called from C++.
    See ``torch/csrc/distributed/rpc/python_rpc_handler.cpp``.

    Deserializes the return value of a Python function and hands it back.
    If the deserialized value is a ``RemoteException``, raises an
    ``Exception`` built from its ``msg`` instead of returning it.
    """
    value = _internal_rpc_pickler.deserialize(binary_data, tensor_table)
    # Guard clause: anything that is not a wrapped exception is returned as is.
    if not isinstance(value, RemoteException):
        return value
    raise Exception(value.msg)
# A Python user-defined function call: the callable plus its positional and
# keyword arguments.
PythonUDF = collections.namedtuple(
    typename="PythonUDF",
    field_names=("func", "args", "kwargs"),
)

# Wrapper for an exception raised while running a UDF; ``msg`` holds the
# exception text.
RemoteException = collections.namedtuple(
    typename="RemoteException",
    field_names=("msg",),
)