Revert D18482934: support torch script call over rpc
Test Plan: revert-hammer
Differential Revision: D18482934
Original commit changeset: bd82a0d820c4
fbshipit-source-id: ca5e50fb0a883ee311aeb310198d84ad28062158
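For context, the reverted change let ``@torch.jit.script`` functions be passed directly to ``rpc.rpc_sync``/``rpc.rpc_async``. A minimal sketch of that now-removed usage, adapted from the docstring examples deleted from torch/distributed/rpc/api.py below (worker names and argument values are illustrative):

    # On worker 0 (pattern removed by this revert):
    import torch
    import torch.distributed.rpc as rpc

    @torch.jit.script
    def my_script_add(t1, t2):
        return torch.add(t1, t2)

    rpc.init_rpc("worker0", rank=0, world_size=2)
    # Before the revert, a scripted function took a dedicated TorchScript path.
    ret = rpc.rpc_sync("worker1", my_script_add, args=(torch.ones(2), 3))
    rpc.shutdown()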
diff --git a/test/dist_autograd_test.py b/test/dist_autograd_test.py
index 8eb78ad..0b0c1c4 100644
--- a/test/dist_autograd_test.py
+++ b/test/dist_autograd_test.py
@@ -79,9 +79,6 @@
ret = torch.add(rref_t1.local_value(), t2)
return ret
-@torch.jit.script
-def my_script_add(t1, t2):
- return torch.add(t1, t2)
def my_nested_rref_add(dst, rref_t1, t2):
return rpc.rpc_sync(dst, my_rref_add, args=(rref_t1, t2))
@@ -150,7 +147,6 @@
LOCAL = 1 # Run the operation locally.
RPC_SYNC = 2 # Run the operation using rpc_sync
REMOTE = 3 # Run the operation using remote.
- RPC_ASYNC = 4 # Run the operation using rpc_async
@unittest.skipIf(
not torch._six.PY3, "Pytorch distributed autograd package " "does not support python2"
@@ -168,10 +164,6 @@
elif ExecMode.REMOTE == exec_mode:
return rpc.remote('worker{}'.format(self._next_rank()), method,
args=(args)).to_here()
- elif ExecMode.RPC_ASYNC == exec_mode:
- fut = rpc.rpc_async('worker{}'.format(self._next_rank()), method,
- args=(args))
- return fut.wait()
else:
raise ValueError("Unrecognized ExecMode {}".format(exec_mode))
@@ -1177,20 +1169,6 @@
loss = ret.sum()
local_grads = self._verify_backwards(exec_mode, [loss], context_id, local_grads, t1, t2)
- @dist_init
- def test_backward_simple_script_call(self):
- # Run the same code locally and with dist autograd, and verify that the
- # gradients are the same.
- local_grads = None
- t1 = torch.rand((3, 3), requires_grad=True)
- t2 = torch.rand((3, 3), requires_grad=True)
- for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC, ExecMode.RPC_ASYNC]:
- with dist_autograd.context() as context_id:
- ret = self._exec_func(exec_mode, my_script_add, t1, t2)
- loss = ret.sum()
- ret = self._verify_backwards(exec_mode, [loss], context_id, local_grads, t1, t2)
- local_grads = ret if ret else local_grads
-
@staticmethod
def _complex_python_udf(t1, t2):
t3 = torch.nn.functional.linear(t1, t2)
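The deleted ``test_backward_simple_script_call`` followed the same local-vs-distributed gradient comparison used by the surrounding tests. A condensed sketch of that pattern, with ``torch.add`` standing in for the removed ``my_script_add`` and the removed ``RPC_ASYNC`` mode dropped; ``_exec_func`` and ``_verify_backwards`` are the fixture helpers referenced above, and the method name here is illustrative:

    @dist_init
    def test_backward_simple_add(self):
        t1 = torch.rand((3, 3), requires_grad=True)
        t2 = torch.rand((3, 3), requires_grad=True)
        local_grads = None
        for exec_mode in [ExecMode.LOCAL, ExecMode.RPC_SYNC]:
            with dist_autograd.context() as context_id:
                ret = self._exec_func(exec_mode, torch.add, t1, t2)
                loss = ret.sum()
                # Compare gradients against the locally computed ones.
                ret = self._verify_backwards(
                    exec_mode, [loss], context_id, local_grads, t1, t2)
                local_grads = ret if ret else local_grads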
diff --git a/test/rpc_test.py b/test/rpc_test.py
index 6901cee..906d3d3 100644
--- a/test/rpc_test.py
+++ b/test/rpc_test.py
@@ -18,7 +18,6 @@
from torch.distributed.rpc.api import _use_rpc_pickler
from torch.distributed.rpc.internal import PythonUDF, _internal_rpc_pickler, RPCExecMode
from rpc_agent_test_fixture import RpcAgentTestFixture
-from torch._jit_internal import _qualified_name
def requires_process_group_agent(message=""):
@@ -226,23 +225,6 @@
global global_rref
global_rref = None
-
-@torch.jit.script
-class MyScriptClass:
- def __init__(self):
- self.a = 10
-
-
-class MyScriptModule(torch.jit.ScriptModule):
- def __init__(self):
- super().__init__()
- self.a = 10
-
- @torch.jit.script_method
- def my_method(self):
- self.a = 11
-
-
# load_tests from common_utils is used to automatically filter tests for
# sharding on sandcastle. This line silences flake warnings
load_tests = load_tests
@@ -730,51 +712,6 @@
fut.wait()
@dist_init
- def test_script_function_exception(self):
- @torch.jit.script
- def no_args():
- a = 1
- n = self.rank + 1
- dst_rank = n % self.world_size
- with self.assertRaisesRegex(Exception, "no_args"):
- ret = rpc.rpc_sync("worker{}".format(dst_rank), no_args, args=(10,))
-
- @dist_init
- def test_script_functions_not_supported(self):
- # Right now _rpc_sync_torchscript does not accept an annotated TorchScript
- # class name, a script module class name, or their class method names.
- # However, rpc_sync still accepts a script class name and runs it through
- # the same code path as a Python call.
- # Currently neither rpc_sync nor _rpc_sync_torchscript accepts a script
- # module or a script module method.
- n = self.rank + 1
- dst_rank = n % self.world_size
- with self.assertRaisesRegex(RuntimeError, "attempted to get undefined function"):
- ret = rpc._rpc_sync_torchscript(
- 'worker{}'.format(dst_rank),
- _qualified_name(MyScriptClass),
- args=())
- ret = rpc.rpc_sync(
- 'worker{}'.format(dst_rank), MyScriptClass, args=())
-
- with self.assertRaisesRegex(RuntimeError, "attempted to get undefined function"):
- ret = rpc._rpc_sync_torchscript(
- 'worker{}'.format(dst_rank),
- _qualified_name(MyScriptModule),
- args=())
-
- with self.assertRaisesRegex(RuntimeError, "attempted to get undefined function"):
- ret = rpc._rpc_sync_torchscript(
- 'worker{}'.format(dst_rank),
- _qualified_name(MyScriptModule().my_method),
- args=())
- with self.assertRaisesRegex(TypeError, "can't pickle"):
- ret = rpc.rpc_sync(
- 'worker{}'.format(dst_rank),
- MyScriptModule().my_method,
- args=())
-
- @dist_init
def test_nested_rpc(self):
n = self.rank + 1
dst_rank = n % self.world_size
diff --git a/tools/build_variables.py b/tools/build_variables.py
index ec208ea..9e450ff 100644
--- a/tools/build_variables.py
+++ b/tools/build_variables.py
@@ -308,7 +308,6 @@
"torch/csrc/distributed/rpc/request_callback_impl.cpp",
"torch/csrc/distributed/rpc/rref_context.cpp",
"torch/csrc/distributed/rpc/rref_impl.cpp",
- "torch/csrc/distributed/rpc/script_functions.cpp",
"torch/csrc/jit/init.cpp",
"torch/csrc/jit/passes/inline_fork_wait.cpp",
"torch/csrc/jit/passes/onnx.cpp",
diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt
index 25b5cfa..0fa95ff 100644
--- a/torch/CMakeLists.txt
+++ b/torch/CMakeLists.txt
@@ -247,7 +247,6 @@
${TORCH_SRC_DIR}/csrc/distributed/rpc/request_callback_impl.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/rref_context.cpp
${TORCH_SRC_DIR}/csrc/distributed/rpc/rref_impl.cpp
- ${TORCH_SRC_DIR}/csrc/distributed/rpc/script_functions.cpp
)
list(APPEND TORCH_PYTHON_LINK_LIBRARIES c10d)
list(APPEND TORCH_PYTHON_COMPILE_DEFINITIONS USE_C10D)
diff --git a/torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.cpp b/torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.cpp
index 8bb8397..5cbadff 100644
--- a/torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.cpp
+++ b/torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.cpp
@@ -137,7 +137,7 @@
if (originalMessageType == MessageType::FORWARD_AUTOGRAD_REQ) {
wrappedRpc = deserializeRequest(wrappedMessage);
} else {
- wrappedRpc = deserializeResponse(wrappedMessage, wrappedMessageType);
+ wrappedRpc = deserializeResponse(wrappedMessage);
}
return std::make_unique<RpcWithAutograd>(
@@ -162,11 +162,6 @@
return *wrappedRpc_;
}
-std::unique_ptr<RpcCommandBase> RpcWithAutograd::moveWrappedRpc() && {
- TORCH_INTERNAL_ASSERT(wrappedRpc_ != nullptr, "wrappedRpc cannot be null!");
- return std::move(wrappedRpc_);
-}
-
MessageType RpcWithAutograd::wrappedMessageType() const {
return wrappedMessageType_;
}
diff --git a/torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h b/torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h
index eb99d57..4f00e43 100644
--- a/torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h
+++ b/torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h
@@ -42,8 +42,6 @@
RpcCommandBase& wrappedRpc();
- std::unique_ptr<RpcCommandBase> moveWrappedRpc() &&;
-
// Message type of the wrapped RPC.
rpc::MessageType wrappedMessageType() const;
diff --git a/torch/csrc/distributed/rpc/init.cpp b/torch/csrc/distributed/rpc/init.cpp
index 56f455a..fae43dc 100644
--- a/torch/csrc/distributed/rpc/init.cpp
+++ b/torch/csrc/distributed/rpc/init.cpp
@@ -3,10 +3,8 @@
#include <torch/csrc/distributed/rpc/process_group_agent.h>
#include <torch/csrc/distributed/rpc/py_rref.h>
#include <torch/csrc/distributed/rpc/python_functions.h>
-#include <torch/csrc/distributed/rpc/python_rpc_handler.h>
#include <torch/csrc/distributed/rpc/rpc_agent.h>
#include <torch/csrc/distributed/rpc/rref_context.h>
-#include <torch/csrc/distributed/rpc/script_functions.h>
#include <torch/csrc/distributed/rpc/types.h>
#include <torch/csrc/jit/pybind_utils.h>
#include <torch/csrc/utils/object_ptr.h>
@@ -292,54 +290,6 @@
py::arg("tensors"),
py::arg("rf") = nullptr);
- // TODO This python future wrapper wraps c10::ivalue::Future.
- // It will be merged with the JIT PythonFutureWrapper once the generic Future
- // is merged with c10::ivalue::Future.
- struct PythonFutureWrapper {
- explicit PythonFutureWrapper(c10::intrusive_ptr<c10::ivalue::Future> fut)
- : fut(std::move(fut)) {}
-
- c10::intrusive_ptr<c10::ivalue::Future> fut;
- };
-
- // Since FutureMessage is bound to Future, we need to bind the
- // PythonFutureWrapper to a different name here.
- // TODO Once a Python object can be tagged as an IValue and c10::ivalue::Future
- // is implemented as a generic Future<IValue>, we can consider having all rpc
- // calls return a Future<IValue>.
- shared_ptr_class_<PythonFutureWrapper>(module, "_pyFuture")
- .def("wait", [](PythonFutureWrapper& fut) {
- fut.fut->wait();
- auto res = fut.fut->value();
- {
- // acquiring GIL as torch::jit::toPyObject creates new py::object
- // without grabbing the GIL.
- AutoGIL ag;
- return torch::jit::toPyObject(std::move(res));
- }
- });
-
- module.def(
- "_invoke_rpc_script",
- [](const std::string& dst,
- const std::string& qualifiedName,
- const py::args& args,
- const py::kwargs& kwargs) {
- // No need to catch exceptions here: if the function cannot be found, an
- // exception is thrown by the get_function() call; if the args do not match
- // the function schema, an exception is thrown by the
- // createStackForSchema() call.
- auto name = c10::QualifiedName(qualifiedName);
- auto fnSchema = PythonRpcHandler::getInstance()
- .jitCompilationUnit()
- ->get_function(name)
- .getSchema();
- auto stack = torch::jit::createStackForSchema(
- fnSchema, args, kwargs, c10::nullopt);
- auto fut = rpcTorchscriptCall(dst, name, stack);
- return PythonFutureWrapper(fut);
- });
-
module.def(
"_invoke_remote_builtin",
[](RpcAgent& agent,
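The ``_invoke_rpc_script`` binding removed here was only reached through the private wrappers deleted from torch/distributed/rpc/api.py further down. A condensed sketch of that call path (``dst`` and ``qualified_name`` are placeholders; the returned object is the ``_pyFuture`` defined above):

    # Inside the removed _rpc_sync_torchscript()/_rpc_async_torchscript() wrappers:
    fut = _invoke_rpc_script(dst, qualified_name, *args, **kwargs)
    # _pyFuture.wait() converts the resulting IValue into a Python object.
    ret = fut.wait()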
diff --git a/torch/csrc/distributed/rpc/python_functions.cpp b/torch/csrc/distributed/rpc/python_functions.cpp
index eff3010..96bb2ea 100644
--- a/torch/csrc/distributed/rpc/python_functions.cpp
+++ b/torch/csrc/distributed/rpc/python_functions.cpp
@@ -124,6 +124,19 @@
return PythonRpcHandler::getInstance().loadPythonUDFResult(
resp.pickledPayload(), resp.tensors());
}
+ case MessageType::FORWARD_AUTOGRAD_RESP: {
+ auto& rpcWithAutograd = static_cast<RpcWithAutograd&>(rpc);
+
+ // Attach 'recv' autograd function.
+ addRecvRpcBackward(
+ rpcWithAutograd.autogradMetadata(),
+ rpcWithAutograd.tensors(),
+ rpcWithAutograd.fromWorkerId());
+
+ // Handle the original RPC.
+ auto wrappedMessageType = rpcWithAutograd.wrappedMessageType();
+ return toPyObjInternal(rpcWithAutograd.wrappedRpc(), wrappedMessageType);
+ }
default: {
TORCH_CHECK(false, "Unrecognized response message type ", messageType);
}
@@ -131,9 +144,7 @@
}
py::object toPyObj(const Message& message) {
- MessageType msgType = message.type();
- auto response = deserializeResponse(message, msgType);
- return toPyObjInternal(*response, msgType);
+ return toPyObjInternal(*deserializeResponse(message), message.type());
}
std::shared_ptr<FutureMessage> pyRpcBuiltin(
diff --git a/torch/csrc/distributed/rpc/python_rpc_handler.cpp b/torch/csrc/distributed/rpc/python_rpc_handler.cpp
index 1b8afab..7d9f21d 100644
--- a/torch/csrc/distributed/rpc/python_rpc_handler.cpp
+++ b/torch/csrc/distributed/rpc/python_rpc_handler.cpp
@@ -1,5 +1,4 @@
#include <torch/csrc/distributed/rpc/python_rpc_handler.h>
-#include <torch/csrc/jit/pybind_utils.h>
namespace torch {
namespace distributed {
@@ -26,7 +25,6 @@
pyLoadReturnValue_ = getFunction(module, "_load_return_value");
pySerialize_ = getFunction(module, "serialize");
pyHandleException_ = getFunction(module, "_handle_exception");
- jitCompilationUnit_ = torch::jit::get_python_cu();
}
void PythonRpcHandler::cleanup() {
@@ -35,7 +33,6 @@
pyLoadReturnValue_ = py::none();
pySerialize_ = py::none();
pyHandleException_ = py::none();
- jitCompilationUnit_ = nullptr;
}
PythonRpcHandler& PythonRpcHandler::getInstance() {
@@ -43,11 +40,6 @@
return handler;
}
-std::shared_ptr<torch::jit::script::CompilationUnit> PythonRpcHandler::
- jitCompilationUnit() {
- return jitCompilationUnit_;
-}
-
std::vector<char> PythonRpcHandler::generatePythonUDFResult(
const std::vector<char>& pickledPayload,
const std::vector<torch::Tensor>& requestTensorTable,
diff --git a/torch/csrc/distributed/rpc/python_rpc_handler.h b/torch/csrc/distributed/rpc/python_rpc_handler.h
index 68c7bd8..9dddb24 100644
--- a/torch/csrc/distributed/rpc/python_rpc_handler.h
+++ b/torch/csrc/distributed/rpc/python_rpc_handler.h
@@ -57,8 +57,6 @@
// PythonRpcHandler.
void cleanup();
- std::shared_ptr<torch::jit::script::CompilationUnit> jitCompilationUnit();
-
private:
PythonRpcHandler();
~PythonRpcHandler() = default;
@@ -79,13 +77,6 @@
// Ref to 'torch.distributed.rpc.internal._handle_exception'
py::object pyHandleException_;
-
- // Shared ptr to the JIT Python compilation unit. It is constructed on the
- // Python side (see _python_cu = torch._C.CompilationUnit() in jit/__init__.py)
- // and imported in C++ (see get_python_cu() in csrc/jit/pybind_utils.h).
- // We import the compilation unit here only once, to reduce cost and for
- // thread safety.
- std::shared_ptr<torch::jit::script::CompilationUnit> jitCompilationUnit_;
};
} // namespace rpc
diff --git a/torch/csrc/distributed/rpc/request_callback_impl.cpp b/torch/csrc/distributed/rpc/request_callback_impl.cpp
index 95a39d0..4ad0e2e 100644
--- a/torch/csrc/distributed/rpc/request_callback_impl.cpp
+++ b/torch/csrc/distributed/rpc/request_callback_impl.cpp
@@ -49,15 +49,7 @@
// sc is only alive within this block, use reference to avoid copy
auto& stack = scriptCall.stackRef();
- at::IValue res;
- if (scriptCall.hasOp()) {
- scriptCall.op()->getOperation()(stack);
- } else {
- PythonRpcHandler::getInstance()
- .jitCompilationUnit()
- ->get_function(scriptCall.qualifiedName())
- .run(stack);
- }
+ scriptCall.op()->getOperation()(stack);
TORCH_INTERNAL_ASSERT(
stack.size() == 1,
diff --git a/torch/csrc/distributed/rpc/rref_impl.cpp b/torch/csrc/distributed/rpc/rref_impl.cpp
index af128c9..a1f582f 100644
--- a/torch/csrc/distributed/rpc/rref_impl.cpp
+++ b/torch/csrc/distributed/rpc/rref_impl.cpp
@@ -22,6 +22,27 @@
// NB: if more fields are added, make sure this field is also bumped
constexpr int RFD_TUPLE_SIZE = 6; // number of RRefForkData fields in py::tuple
+
+template <typename T>
+T& unwrapAutogradMessage(
+ const Message& message,
+ std::unique_ptr<RpcCommandBase>& response) {
+ if (message.type() == MessageType::FORWARD_AUTOGRAD_RESP) {
+ auto& rpcWithAutograd = static_cast<autograd::RpcWithAutograd&>(*response);
+
+ // Attach 'recv' autograd function.
+ addRecvRpcBackward(
+ rpcWithAutograd.autogradMetadata(),
+ rpcWithAutograd.tensors(),
+ rpcWithAutograd.fromWorkerId());
+
+ auto& wrappedRpc = rpcWithAutograd.wrappedRpc();
+ return static_cast<T&>(wrappedRpc);
+ } else {
+ return static_cast<T&>(*response);
+ }
+}
+
} // namespace
std::atomic<local_id_t> RRefContext::nextLocalId_{0};
@@ -121,13 +142,8 @@
true /* forceGradRecording */);
const Message& message = futureResponse->wait();
- MessageType msgType = message.type();
- auto response = deserializeResponse(message, msgType);
- TORCH_INTERNAL_ASSERT(
- msgType == MessageType::SCRIPT_RREF_FETCH_RET,
- "Message type should be SCRIPT_RREF_FETCH_RET.");
- RpcCommandBase& rpc = *response;
- auto& rfr = static_cast<ScriptRRefFetchRet&>(rpc);
+ auto response = deserializeResponse(message);
+ auto& rfr = unwrapAutogradMessage<ScriptRRefFetchRet>(message, response);
return rfr.values().front();
}
@@ -145,13 +161,8 @@
true /* forceGradRecording */);
const Message& message = futureResponse->wait();
- MessageType msgType = message.type();
- auto response = deserializeResponse(message, msgType);
- TORCH_INTERNAL_ASSERT(
- msgType == MessageType::PYTHON_RREF_FETCH_RET,
- "Message type should be PYTHON_RREF_FETCH_RET.");
- RpcCommandBase& rpc = *response;
- auto& rfr = static_cast<PythonRRefFetchRet&>(rpc);
+ auto response = deserializeResponse(message);
+ auto& rfr = unwrapAutogradMessage<PythonRRefFetchRet>(message, response);
return PythonRpcHandler::getInstance().deserialize(
SerializedPyObj::fromIValues(rfr.values()));
}
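The ``unwrapAutogradMessage`` helper added above is hit whenever an RRef fetch response arrives wrapped as FORWARD_AUTOGRAD_RESP, i.e., when the fetch happens inside a distributed autograd context. A minimal sketch of the user-level path that exercises it, following the ``ExecMode.REMOTE`` branch of the test fixture shown earlier (worker name and shapes are illustrative):

    import torch
    import torch.distributed.autograd as dist_autograd
    import torch.distributed.rpc as rpc

    t1 = torch.rand((3, 3), requires_grad=True)
    t2 = torch.rand((3, 3), requires_grad=True)
    with dist_autograd.context() as context_id:
        # remote() returns an RRef; to_here() fetches the value, and the change
        # above attaches the 'recv' autograd function to the wrapped response.
        rref = rpc.remote("worker1", torch.add, args=(t1, t2))
        ret = rref.to_here()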
diff --git a/torch/csrc/distributed/rpc/script_call.cpp b/torch/csrc/distributed/rpc/script_call.cpp
index 737f989..1d241e0 100644
--- a/torch/csrc/distributed/rpc/script_call.cpp
+++ b/torch/csrc/distributed/rpc/script_call.cpp
@@ -13,27 +13,10 @@
std::vector<at::IValue>&& args)
: op_(std::move(op)), stack_(args) {}
-ScriptCall::ScriptCall(
- const c10::QualifiedName& qualifiedName,
- std::vector<at::IValue>&& args)
- : qualifiedName_(qualifiedName), stack_(args) {}
-
-bool ScriptCall::hasOp() const {
- return op_ ? true : false;
-}
-
std::shared_ptr<Operator> ScriptCall::op() const {
return *op_;
}
-bool ScriptCall::hasQualifiedName() const {
- return qualifiedName_ ? true : false;
-}
-
-const c10::QualifiedName ScriptCall::qualifiedName() const {
- return *qualifiedName_;
-}
-
const std::vector<at::IValue>& ScriptCall::stack() const {
return stack_;
}
@@ -47,10 +30,7 @@
ivalues.push_back(value);
}
- if (hasOp()) {
- TORCH_CHECK(
- !hasQualifiedName(),
- "It is builtin operator call, qualifiedName_ should not be set.");
+ if (op_) {
// TODO: replace this with a real overload_name when FunctionSchema supports
// that.
ivalues.emplace_back(toString((*op_)->schema()));
@@ -64,24 +44,11 @@
// aten::add -> torch.ops.aten.add
opName.replace(0, ATEN_PREFIX_.length(), BUILTIN_OP_NAMESPACE_);
ivalues.emplace_back(std::move(opName));
- } else if (hasQualifiedName()) {
- TORCH_CHECK(
- !hasOp(),
- "It is TorchScript function call, operator should not be set.");
- ivalues.emplace_back((*qualifiedName_).qualifiedName());
- } else {
- TORCH_INTERNAL_ASSERT(
- false,
- "Either builtin operator or TorchScript function name should be set.");
}
}
-std::unique_ptr<ScriptCall> ScriptCall::fromIValues(
+std::shared_ptr<Operator> ScriptCall::fromIValues(
std::vector<at::IValue>& ivalues) {
- // The last element in the vector is always the qualifiedName, for both
- // builtin operators and TorchScript functions.
- // If the qualifiedName is not a builtin operator name, treat it as a
- // TorchScript function name.
const std::string& qualifiedName = ivalues.back().toStringRef();
if (qualifiedName.rfind(BUILTIN_OP_NAMESPACE_) == 0) {
@@ -91,11 +58,9 @@
ivalues.pop_back();
// remove str_schema from ivalues
- return std::make_unique<ScriptCall>(op, std::move(ivalues));
+ return op;
} else {
- ivalues.pop_back();
- return std::make_unique<ScriptCall>(
- c10::QualifiedName(qualifiedName), std::move(ivalues));
+ TORCH_CHECK(false, "Unrecognized qualified name ", qualifiedName);
}
}
@@ -118,7 +83,8 @@
jit::unpickle(payload, payload_size, nullptr, &message.tensors());
auto values = value.toTuple()->elements();
- return fromIValues(values);
+ auto op = fromIValues(values);
+ return std::make_unique<ScriptCall>(op, std::move(values));
}
std::shared_ptr<Operator> ScriptCall::matchOperator(
diff --git a/torch/csrc/distributed/rpc/script_call.h b/torch/csrc/distributed/rpc/script_call.h
index 49f174a..c295476 100644
--- a/torch/csrc/distributed/rpc/script_call.h
+++ b/torch/csrc/distributed/rpc/script_call.h
@@ -14,23 +14,13 @@
using torch::jit::Operator;
// A ScriptCall instance represents an invocation of a builtin operator for a
-// TorchScript function. If it is a builtin operator, it
+// TorchScript function (not implemented yet). If it is a builtin operator, it
// contains a shared ptr to the `Operator` and a list of arguments.
-// If it is a TorchScript function, it contains a non-empty qualified name string
-// identifying the TorchScript function schema, and a list of arguments.
class TORCH_API ScriptCall : public RpcCommandBase {
public:
- // Constructor for a builtin operator call.
ScriptCall(std::shared_ptr<Operator> op, std::vector<at::IValue>&& args);
- // Constructor for TorchScript function call.
- ScriptCall(
- const c10::QualifiedName& qualifiedName,
- std::vector<at::IValue>&& args);
- bool hasOp() const;
std::shared_ptr<Operator> op() const;
- bool hasQualifiedName() const;
- const c10::QualifiedName qualifiedName() const;
// return the argument stack of this builtin operator
const std::vector<at::IValue>& stack() const;
std::vector<at::IValue>& stackRef();
@@ -42,7 +32,7 @@
protected:
virtual void toIValues(std::vector<at::IValue>& ivalues) const;
- static std::unique_ptr<ScriptCall> fromIValues(
+ static std::shared_ptr<Operator> fromIValues(
std::vector<at::IValue>& ivalues);
private:
@@ -55,9 +45,6 @@
// This field has value if this ScriptCall represents invocation of a builtin
// operator.
c10::optional<std::shared_ptr<Operator>> op_;
- // This field holds a non-empty string if this ScriptCall represents an
- // invocation of a user-annotated TorchScript function.
- c10::optional<const c10::QualifiedName> qualifiedName_;
std::vector<at::IValue> stack_;
};
diff --git a/torch/csrc/distributed/rpc/script_functions.cpp b/torch/csrc/distributed/rpc/script_functions.cpp
deleted file mode 100644
index 8ad9091..0000000
--- a/torch/csrc/distributed/rpc/script_functions.cpp
+++ /dev/null
@@ -1,55 +0,0 @@
-#include <torch/csrc/distributed/rpc/script_functions.h>
-
-#include <torch/csrc/distributed/autograd/utils.h>
-#include <torch/csrc/distributed/rpc/message.h>
-#include <torch/csrc/distributed/rpc/python_rpc_handler.h>
-#include <torch/csrc/distributed/rpc/rpc_agent.h>
-#include <torch/csrc/distributed/rpc/script_call.h>
-#include <torch/csrc/distributed/rpc/utils.h>
-
-namespace torch {
-namespace distributed {
-namespace rpc {
-
-c10::intrusive_ptr<c10::ivalue::Future> rpcTorchscriptCall(
- const std::string& dst,
- const c10::QualifiedName& qualifiedName,
- std::vector<c10::IValue>& stack) {
- auto scriptCall =
- std::make_unique<ScriptCall>(qualifiedName, std::move(stack));
- auto agent = RpcAgent::getDefaultRpcAgent();
- auto futMessage = autograd::sendMessageWithAutograd(
- *agent, agent->getWorkerInfo(dst), std::move(*scriptCall).toMessage());
- // Get the function return type to construct the c10::ivalue::Future.
- // A script call only allows a single returned IValue.
- auto returns = PythonRpcHandler::getInstance()
- .jitCompilationUnit()
- ->get_function(qualifiedName)
- .getSchema()
- .returns();
- TORCH_INTERNAL_ASSERT(
- returns.size() == 1,
- "Return value of an annotated torchScript function should be a single "
- "IValue.",
- returns.size());
- auto returnType = returns.at(0).type();
-
- // Create a JIT future and pass it to futMessage's callback to set state
- // of the JIT future.
- auto futPtr = c10::make_intrusive<c10::ivalue::Future>(returnType);
- futMessage->addCallback([futPtr](
- const rpc::Message& message,
- const c10::optional<utils::FutureError>& futErr) {
- if (futErr) {
- c10::ivalue::Future::FutureError jitFutErr(std::string((*futErr).what()));
- futPtr->markCompleted(std::move(jitFutErr));
- } else {
- futPtr->markCompleted(deserializeRespToIValue(message));
- }
- });
- return futPtr;
-}
-
-} // namespace rpc
-} // namespace distributed
-} // namespace torch
diff --git a/torch/csrc/distributed/rpc/script_functions.h b/torch/csrc/distributed/rpc/script_functions.h
deleted file mode 100644
index 6772193..0000000
--- a/torch/csrc/distributed/rpc/script_functions.h
+++ /dev/null
@@ -1,26 +0,0 @@
-#pragma once
-
-#include <ATen/core/ivalue.h>
-
-namespace torch {
-namespace distributed {
-namespace rpc {
-
-// This function sends an rpc call to run a TorchScript function. Currently the
-// TorchScript function can only be a user-defined Python function annotated
-// with "@torch.jit.script". It cannot be a class constructor, class method,
-// instance method, or a script module.
-// dst: destination worker name
-// qualifiedName: TorchScript function qualified name string like
-// "moduleName::torchscriptFunctionName", e.g.,
-// "dist_autograd_test::my_py_add"
-// stack: a bag of IValue args passed to the TorchScript function
-// It returns a c10::intrusive_ptr<ivalue::Future>.
-c10::intrusive_ptr<c10::ivalue::Future> rpcTorchscriptCall(
- const std::string& dst,
- const c10::QualifiedName& qualifiedName,
- std::vector<c10::IValue>& stack);
-
-} // namespace rpc
-} // namespace distributed
-} // namespace torch
diff --git a/torch/csrc/distributed/rpc/script_remote_call.cpp b/torch/csrc/distributed/rpc/script_remote_call.cpp
index 068b6c8..633b8a7 100644
--- a/torch/csrc/distributed/rpc/script_remote_call.cpp
+++ b/torch/csrc/distributed/rpc/script_remote_call.cpp
@@ -47,9 +47,9 @@
auto retRRefId = ForkId::fromIValue(values.back());
values.pop_back();
- auto scriptCallPtr = ScriptCall::fromIValues(values);
+ auto op = ScriptCall::fromIValues(values);
return std::make_unique<ScriptRemoteCall>(
- scriptCallPtr->op(), std::move(values), retRRefId, retForkId);
+ op, std::move(values), std::move(retRRefId), std::move(retForkId));
}
} // namespace rpc
diff --git a/torch/csrc/distributed/rpc/utils.cpp b/torch/csrc/distributed/rpc/utils.cpp
index 156f35f..a987042 100644
--- a/torch/csrc/distributed/rpc/utils.cpp
+++ b/torch/csrc/distributed/rpc/utils.cpp
@@ -3,9 +3,7 @@
#include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_req.h>
#include <torch/csrc/distributed/autograd/rpc_messages/cleanup_autograd_context_resp.h>
#include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.h>
-#include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_resp.h>
#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
-#include <torch/csrc/distributed/autograd/utils.h>
#include <torch/csrc/distributed/rpc/python_call.h>
#include <torch/csrc/distributed/rpc/python_remote_call.h>
#include <torch/csrc/distributed/rpc/python_resp.h>
@@ -65,9 +63,7 @@
}
}
-std::unique_ptr<RpcCommandBase> deserializeResponse(
- const Message& response,
- MessageType& wrappedMsgType) {
+std::unique_ptr<RpcCommandBase> deserializeResponse(const Message& response) {
switch (response.type()) {
case MessageType::SCRIPT_RET: {
return ScriptResp::fromMessage(response);
@@ -88,23 +84,10 @@
return RRefAck::fromMessage(response);
}
case MessageType::FORWARD_AUTOGRAD_RESP: {
- std::unique_ptr<RpcCommandBase> rpcPtr =
- autograd::RpcWithAutograd::fromMessage(response);
- RpcCommandBase& rpc = *rpcPtr;
- auto& rpcWithAutograd = static_cast<autograd::RpcWithAutograd&>(rpc);
-
- // Attach 'recv' autograd function.
- addRecvRpcBackward(
- rpcWithAutograd.autogradMetadata(),
- rpcWithAutograd.tensors(),
- rpcWithAutograd.fromWorkerId());
-
- wrappedMsgType = rpcWithAutograd.wrappedMessageType();
-
- return std::move(rpcWithAutograd).moveWrappedRpc();
+ return autograd::RpcWithAutograd::fromMessage(response);
}
case MessageType::BACKWARD_AUTOGRAD_RESP: {
- return autograd::PropagateGradientsResp::fromMessage(response);
+ return autograd::RpcWithAutograd::fromMessage(response);
}
case MessageType::CLEANUP_AUTOGRAD_CONTEXT_RESP: {
return autograd::CleanupAutogradContextResp::fromMessage(response);
@@ -116,30 +99,6 @@
}
}
-IValue deserializeResptoIValueInternal(
- RpcCommandBase& rpc,
- MessageType messageType) {
- switch (messageType) {
- case MessageType::SCRIPT_RET: {
- auto& ret = static_cast<ScriptResp&>(rpc);
- return ret.value();
- }
- default: {
- TORCH_INTERNAL_ASSERT(
- false,
- "Response type ",
- messageType,
- " is not supported to be deserialized to IValue.");
- }
- }
-}
-
-IValue deserializeRespToIValue(const Message& message) {
- MessageType msgType = message.type();
- auto response = deserializeResponse(message, msgType);
- return deserializeResptoIValueInternal(*response, msgType);
-}
-
namespace {
// Helper for wireDeserialize() below.
diff --git a/torch/csrc/distributed/rpc/utils.h b/torch/csrc/distributed/rpc/utils.h
index a4f7506..f1c4d57 100644
--- a/torch/csrc/distributed/rpc/utils.h
+++ b/torch/csrc/distributed/rpc/utils.h
@@ -12,22 +12,9 @@
const Message& request);
// Given an RPC message received as a response over the wire, deserialize it
-// into the appropriate 'RpcCommandBase' type. If the response is of
-// FORWARD_AUTOGRAD_RESP type, unwrap it, attach recvBackward() functions
-// to the received tensors, and set wrappedMsgType to its wrapped message type.
+// into the appropriate 'RpcCommandBase' type.
TORCH_API std::unique_ptr<RpcCommandBase> deserializeResponse(
- const Message& response,
- MessageType& wrappedMsgType);
-
-// Given an RPC message received as a response over the wire, deserialize it
-// into a valid IValue if the message is a script rpc result; otherwise,
-// deserialize it into a dummy None IValue that will never be used.
-// During this deserialization, we also attach recv rpc backward functions if
-// needed.
-IValue deserializeResptoIValueInternal(
- RpcCommandBase& rpc,
- MessageType messageType);
-TORCH_API IValue deserializeRespToIValue(const Message& message);
+ const Message& response);
// Note: format is subject to change and intended for RPCs.
// For saving persistently to disk, use torch::save().
diff --git a/torch/distributed/rpc/__init__.py b/torch/distributed/rpc/__init__.py
index 7dd1e04..6f7455e 100644
--- a/torch/distributed/rpc/__init__.py
+++ b/torch/distributed/rpc/__init__.py
@@ -19,7 +19,6 @@
if is_available():
from .api import _init_rpc_backend, _require_initialized
- from .api import _rpc_sync_torchscript, _rpc_async_torchscript
from .api import * # noqa: F401
import torch.distributed.autograd as dist_autograd
diff --git a/torch/distributed/rpc/api.py b/torch/distributed/rpc/api.py
index 64769fd..7d923c7 100644
--- a/torch/distributed/rpc/api.py
+++ b/torch/distributed/rpc/api.py
@@ -7,7 +7,6 @@
_invoke_remote_python_udf,
_invoke_rpc_builtin,
_invoke_rpc_python_udf,
- _invoke_rpc_script,
_start_rpc_agent,
backend_registry,
)
@@ -25,7 +24,6 @@
import torch
import torch.distributed as dist
-from torch._jit_internal import _qualified_name
_agent = None
# NB: Ignoring RRef leaks during shutdown. Without this, applications have to
@@ -328,8 +326,8 @@
Arguments:
to (str or WorkerInfo): id or name of the destination worker.
- func (callable): any callable function. builtin or annotated TorchScript
- functions (like :meth:`torch.add`) can be sent over RPC more efficiently.
+ func (callable): any callable function. builtin functions (like
+ :meth:`torch.add`) can be sent over RPC more efficiently.
args (tuple): the argument tuple for the ``func`` invocation.
kwargs (dict): is a dictionary of keyword arguments for the ``func``
invocation.
@@ -358,32 +356,9 @@
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
>>> rpc.shutdown()
-
- If invoking an annotated TorchScript function, then run the following
- code in two different processes:
-
- >>> # On worker 0:
- >>> @torch.jit.script
- >>> def my_script_add(t1, t2):
- >>> return torch.add(t1, t2)
- >>> import torch.distributed.rpc as rpc
- >>> rpc.init_rpc("worker0", rank=0, world_size=2)
- >>> ret = rpc.rpc_sync("worker1", my_script_add, args=(torch.ones(2), 3))
- >>> rpc.shutdown()
-
- >>> # On worker 1:
- >>> import torch.distributed.rpc as rpc
- >>> rpc.init_rpc("worker1", rank=1, world_size=2)
- >>> rpc.shutdown()
-
"""
- # If invoking an annotated TorchScript function,
- # call the internal API _rpc_sync_torchscript()
- if isinstance(func, torch.jit.ScriptFunction):
- return _rpc_sync_torchscript(to, _qualified_name(func), args, kwargs)
- else:
- fut = _invoke_rpc(to, func, RPCExecMode.SYNC, args, kwargs)
- return fut.wait()
+ fut = _invoke_rpc(to, func, RPCExecMode.SYNC, args, kwargs)
+ return fut.wait()
@_require_initialized
@@ -396,8 +371,8 @@
Arguments:
to (str or WorkerInfo): id or name of the destination worker.
- func (callable): any callable function. builtin or annotated TorchScript
- functions (like :meth:`torch.add`) can be sent over RPC more efficiently.
+ func (callable): any callable function. builtin functions (like
+ :meth:`torch.add`) can be sent over RPC more efficiently.
args (tuple): the argument tuple for the ``func`` invocation.
kwargs (dict): is a dictionary of keyword arguments for the ``func``
invocation.
@@ -430,146 +405,6 @@
>>> import torch.distributed.rpc as rpc
>>> rpc.init_rpc("worker1", rank=1, world_size=2)
>>> rpc.shutdown()
-
- If invoking an annotated TorchScript function, then run the following
- code in two different processes:
-
- >>> # On worker 0:
- >>> @torch.jit.script
- >>> def my_script_add(t1, t2):
- >>> return torch.add(t1, t2)
- >>> import torch.distributed.rpc as rpc
- >>> rpc.init_rpc("worker0", rank=0, world_size=2)
- >>> fut = rpc.rpc_async("worker1", my_script_add, args=(torch.ones(2), 3))
- >>> ret = fut.wait()
- >>> rpc.shutdown()
-
- >>> # On worker 1:
- >>> import torch.distributed.rpc as rpc
- >>> rpc.init_rpc("worker1", rank=1, world_size=2)
- >>> rpc.shutdown()
"""
- # If invoking an annotated TorchScript function,
- # call the internal API _rpc_async_torchscript()
- if isinstance(func, torch.jit.ScriptFunction):
- fut = _rpc_async_torchscript(to, _qualified_name(func), args, kwargs)
- else:
- fut = _invoke_rpc(to, func, RPCExecMode.ASYNC, args, kwargs)
- return fut
-
-
-# The private APIs below make an RPC TorchScript call that can be serialized,
-# deserialized, and executed in C++ without the GIL.
-# These APIs will be bound to JIT and callable from TorchScript
-# functions/classes/modules in the future. But since JIT does not yet support
-# treating a TorchScript function as a JIT type, the bound APIs can only accept
-# the qualified_name of the function as an argument, which is why these APIs
-# are private and differ from the public rpc APIs above.
-# For the same reason, right now these APIs only accept user-annotated
-# TorchScript functions; they do not accept annotated TorchScript class names,
-# script module class names, or their class method names.
-@_require_initialized
-def _rpc_sync_torchscript(to, qualified_name, args=None, kwargs=None):
- r"""
- Make a blocking RPC call to run TorchScript function ``func`` on worker ``to``.
- RPC messages are sent and received in parallel to execution of Python code. This
- method is thread-safe.
-
- Arguments:
- to (str): name of the destination worker.
- qualified_name (str): qualified name of a Python function annotated with
- @torch.jit.script (like ``moduleName::torchScriptFuncName``), which
- can be sent over RPC more efficiently than a generic Python callable.
- args (tuple): the argument tuple for the ``func`` invocation.
- kwargs (dict): is a dictionary of keyword arguments for the ``func``
- invocation.
-
- Returns:
- Returns the result of running ``func`` on ``args`` and ``kwargs``.
-
- Example::
- Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly
- on both workers. Refer to :meth:`~torch.distributed.init_process_group`
- API for more details. For example,
-
- >>> export MASTER_ADDR=localhost
- >>> export MASTER_PORT=5678
-
- Then run the following code in two different processes:
-
- >>> # On worker 0:
- >>> @torch.jit.script
- >>> def my_script_add(t1, t2):
- >>> return torch.add(t1, t2)
- >>> import torch.distributed.rpc as rpc
- >>> from torch._jit_internal import _qualified_name
- >>> rpc.init_rpc("worker0", rank=0, world_size=2)
- >>> ret = rpc._rpc_sync_torchscript("worker1", _qualified_name(my_script_add), args=(torch.ones(2), 3))
- >>> rpc.shutdown()
-
- >>> # On worker 1:
- >>> import torch.distributed.rpc as rpc
- >>> rpc.init_rpc("worker1", rank=1, world_size=2)
- >>> rpc.shutdown()
- """
- args = args if args else ()
- kwargs = kwargs if kwargs else {}
- fut = _invoke_rpc_script(to, qualified_name, *args, **kwargs)
- return fut.wait()
-
-
-@_require_initialized
-def _rpc_async_torchscript(to, qualified_name, args=None, kwargs=None):
- r"""
- Make a non-blocking RPC call to run TorchScript function ``func`` on worker ``to``.
- RPC messages are sent and received in parallel to execution of Python code. This
- method is thread-safe. This method will immediately return a
- _pyFuture that can be awaited on.
-
- Arguments:
- to (str): name of the destination worker.
- qualified_name (str): qualified name of a Python function annotated with
- @torch.jit.script (like ``moduleName::torchScriptFuncName``), which
- can be sent over RPC more efficiently than a generic Python callable.
- args (tuple): the argument tuple for the ``func`` invocation.
- kwargs (dict): is a dictionary of keyword arguments for the ``func``
- invocation.
-
- Returns:
- Returns a _pyFuture object that can be waited
- on. When completed, the return value of ``func`` on ``args`` and
- ``kwargs`` can be retrieved from the _pyFuture object.
-
- Example::
- Make sure that ``MASTER_ADDR`` and ``MASTER_PORT`` are set properly
- on both workers. Refer to :meth:`~torch.distributed.init_process_group`
- API for more details. For example,
-
- >>> export MASTER_ADDR=localhost
- >>> export MASTER_PORT=5678
-
- Then run the following code in two different processes:
-
- >>> # On worker 0:
- >>> @torch.jit.script
- >>> def my_script_add(t1, t2):
- >>> return torch.add(t1, t2)
- >>> import torch.distributed.rpc as rpc
- >>> from torch._jit_internal import _qualified_name
- >>> rpc.init_rpc("worker0", rank=0, world_size=2)
- >>> fut = rpc._rpc_async_torchscript("worker1", _qualified_name(my_script_add), args=(torch.ones(2), 3))
- >>> ret = fut.wait()
- >>> rpc.shutdown()
-
- >>> # On worker 1:
- >>> import torch.distributed.rpc as rpc
- >>> rpc.init_rpc("worker1", rank=1, world_size=2)
- >>> rpc.shutdown()
- """
- args = args if args else ()
- kwargs = kwargs if kwargs else {}
- fut = _invoke_rpc_script(to, qualified_name, *args, **kwargs)
+ fut = _invoke_rpc(to, func, RPCExecMode.ASYNC, args, kwargs)
return fut
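After this revert, ``rpc_sync`` and ``rpc_async`` no longer special-case ``torch.jit.ScriptFunction``; only the generic ``_invoke_rpc`` path remains. A minimal sketch of the retained usage, mirroring the builtin-operator pattern kept in the docstrings above (worker names and argument values are illustrative):

    import torch
    import torch.distributed.rpc as rpc

    rpc.init_rpc("worker0", rank=0, world_size=2)
    # Builtin operators still take the efficient non-pickling path.
    fut = rpc.rpc_async("worker1", torch.add, args=(torch.ones(2), 3))
    ret = fut.wait()
    rpc.shutdown()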