Distributed Autograd - FAST mode backward pass implementation. (#27022)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/27022

This change implements the "FAST" mode distributed autograd backward
pass as described in https://github.com/pytorch/pytorch/issues/23110.

At a high level, the backward pass works as follows:
1. We start by computing dependencies on the node that calls
`torch.distributed.autograd.backward`.
2. This node computes the dependencies starting from the root nodes provided in
the backward call and all the 'send' functions present in the current autograd
context. The "FAST" mode assumes all 'send' functions are part of the autograd
computation.
3. Once the dependency computation is done, the distributed autograd engine
calls the local autograd engine to execute the autograd graph. Note that the
autograd graph on a single node is not necessarily connected because of
inter-node communication. As a result, we have special handling to ensure the
local autograd engine executes the entire graph starting from the provided
roots and all 'send' functions on the node.
4. When the local autograd engine hits a 'recv' function, it performs an async
RPC to send the gradients over to the appropriate node and stores a future in
the autograd context to keep track of this RPC.
5. On the destination node, the appropriate 'send' function is looked up and
enqueued on the local autograd engine. If this is the first time the node is
hearing about this autograd context id on the backward pass, then the node
computes dependencies for the local autograd engine.
6. As part of computing dependencies, the distributed autograd engine discovers
all leaf nodes and ensures those are passed as 'outputs' to the local autograd
engine. This avoids running the 'AccumulateGrad' function.
7. The gradients computed for the leaf nodes are then accumulated in
`DistAutogradContext` for the appropriate autograd context id.
8. The distributed autograd engine waits for the local autograd engine
to complete and also waits for all the 'Futures' (stored in step 4) for the
respective RPCs to finish.
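
To make the end-to-end flow concrete, here is a minimal user-level sketch of
the FAST mode flow (it mirrors the new Python tests below; the worker names and
the fact that RPC has already been initialized across the group are
assumptions):

  import torch
  import torch.distributed.autograd as dist_autograd
  import torch.distributed.rpc as rpc

  # Assumes rpc.init_model_parallel(...) has already been called on every
  # worker and that a peer named 'worker1' exists (see test/dist_utils.py).
  t1 = torch.rand((3, 3), requires_grad=True)
  t2 = torch.rand((3, 3), requires_grad=True)

  with dist_autograd.context() as context_id:
      # The forward pass records 'send'/'recv' autograd functions in the
      # context for every RPC that carries tensors requiring grad.
      loss = rpc.rpc_sync('worker1', torch.add, args=(t1, t2)).sum()

      # Steps 1-8 above: compute dependencies, run the local autograd engines
      # on all the involved nodes and accumulate leaf gradients in the context.
      dist_autograd.backward([loss])

      # Gradients are stored per context id rather than in .grad fields.
      grads = dist_autograd.get_gradients(context_id)
      assert t1 in grads and t2 in grads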

We have made the following changes to the local autograd engine for this
purpose:

1. Expose GraphTask and NodeTask so that the distributed autograd engine can
use them.
2. Expose an `execute_with_graph_task` API, which allows the distributed engine
to build a GraphTask and pass it to the local autograd engine.
3. Expose an `enqueue_blocked_task_on_cpu` API, which allows the distributed
engine to build a `NodeTask` for a 'send' function and enqueue it on the local
autograd engine.

In addition to this, a few general improvements were made:
1. Added a `PropagateGradients` RPC call for the 'recv' function to pass
gradients to the appropriate node during the backward pass.
2. Use IValues as much as possible in serialization for RpcWithAutograd.
3. If `Future.wait()` receives a message of type EXCEPTION, we throw an
appropriate exception instead of just returning the message. This is in line
with what most `Future.wait()` APIs do.
4. Added a `get_gradients(context_id)` API which allows users to retrieve a map
from Tensor to respective gradient for the provided context_id on the local
node.
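
As a sketch of (3) in practice (same assumed setup as in the earlier sketch;
SimulateBackwardError is the helper autograd Function added in the tests below,
whose backward() raises), an error encountered during the distributed backward
pass now surfaces as an exception from the backward call instead of being
returned as a message:

  with dist_autograd.context() as context_id:
      t1 = torch.rand((3, 3), requires_grad=True)
      t2 = torch.rand((3, 3), requires_grad=True)
      t3 = SimulateBackwardError.apply(t1)

      # Chain a couple of RPCs so the backward pass crosses node boundaries.
      val = rpc.rpc_sync('worker1', torch.add, args=(t2, t3))
      val = rpc.rpc_sync('worker2', torch.mul, args=(val, t2))

      try:
          dist_autograd.backward([val.sum()])
      except RuntimeError:
          # The failure from the backward pass is propagated back here because
          # Future.wait() now throws for EXCEPTION messages.
          pass
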
ghstack-source-id: 91794926

Test Plan: unit tests.

Differential Revision: D17652615

fbshipit-source-id: 96f65c52adb2706ee29f4b49e1655afaa0a3bec3
diff --git a/caffe2/CMakeLists.txt b/caffe2/CMakeLists.txt
index 33fcf93..c16bb4d 100644
--- a/caffe2/CMakeLists.txt
+++ b/caffe2/CMakeLists.txt
@@ -481,16 +481,21 @@
       ${TORCH_SRC_DIR}/csrc/api/src/jit.cpp
       ${TORCH_SRC_DIR}/csrc/distributed/autograd/context/dist_autograd_container.cpp
       ${TORCH_SRC_DIR}/csrc/distributed/autograd/context/dist_autograd_context.cpp
+      ${TORCH_SRC_DIR}/csrc/distributed/autograd/engine/dist_engine.cpp
       ${TORCH_SRC_DIR}/csrc/distributed/autograd/functions/recvrpc_backward.cpp
       ${TORCH_SRC_DIR}/csrc/distributed/autograd/functions/sendrpc_backward.cpp
+      ${TORCH_SRC_DIR}/csrc/distributed/autograd/rpc_messages/autograd_metadata.cpp
+      ${TORCH_SRC_DIR}/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.cpp
+      ${TORCH_SRC_DIR}/csrc/distributed/autograd/rpc_messages/propagate_gradients_resp.cpp
+      ${TORCH_SRC_DIR}/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.cpp
       ${TORCH_SRC_DIR}/csrc/distributed/autograd/utils.cpp
       ${TORCH_SRC_DIR}/csrc/distributed/rpc/future_message.cpp
       ${TORCH_SRC_DIR}/csrc/distributed/rpc/message.cpp
       ${TORCH_SRC_DIR}/csrc/distributed/rpc/python_remote_call.cpp
       ${TORCH_SRC_DIR}/csrc/distributed/rpc/python_udf_call.cpp
       ${TORCH_SRC_DIR}/csrc/distributed/rpc/python_udf_resp.cpp
+      ${TORCH_SRC_DIR}/csrc/distributed/rpc/rpc_agent.cpp
       ${TORCH_SRC_DIR}/csrc/distributed/rpc/request_callback.cpp
-      ${TORCH_SRC_DIR}/csrc/distributed/rpc/rpc_with_autograd.cpp
       ${TORCH_SRC_DIR}/csrc/distributed/rpc/rref_proto.cpp
       ${TORCH_SRC_DIR}/csrc/distributed/rpc/script_call.cpp
       ${TORCH_SRC_DIR}/csrc/distributed/rpc/script_remote_call.cpp
diff --git a/test/cpp/dist_autograd/test_dist_autograd.cpp b/test/cpp/dist_autograd/test_dist_autograd.cpp
index 7af16e6..1238179 100644
--- a/test/cpp/dist_autograd/test_dist_autograd.cpp
+++ b/test/cpp/dist_autograd/test_dist_autograd.cpp
@@ -3,8 +3,8 @@
 #include <ATen/ATen.h>
 #include <torch/csrc/distributed/autograd/context/dist_autograd_container.h>
 #include <torch/csrc/distributed/autograd/context/dist_autograd_context.h>
+#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
 #include <torch/csrc/distributed/autograd/utils.h>
-#include <torch/csrc/distributed/rpc/rpc_with_autograd.h>
 #include <torch/torch.h>
 
 using namespace torch::distributed::autograd;
@@ -20,38 +20,6 @@
 
 DistAutogradContainer* DistAutogradTest::autogradContainer_ = nullptr;
 
-TEST_F(DistAutogradTest, TestSendFunction) {
-  // Initialize input tensors requiring grad.
-  auto options = at::TensorOptions().requires_grad(true);
-  auto in1 = torch::ones({3, 3}, options);
-  auto in2 = torch::ones({3, 3}, options);
-  ASSERT_FALSE(in1.grad().defined());
-  ASSERT_FALSE(in2.grad().defined());
-
-  autogradContainer_->newContext();
-  DistAutogradContext& autogradContext = autogradContainer_->currentContext();
-  // Attach the send autograd function to tensors.
-  std::vector<torch::Tensor> tensors = {in1, in2};
-  addSendRpcBackward(autogradContext, AutogradMetadata(1, 1), tensors);
-  auto send_function = autogradContext.sendFunctions()[1];
-  ASSERT_NE(send_function, nullptr);
-
-  // Build loss and attach it as input to send autograd function.
-  auto o1 = torch::autograd::Variable(torch::ones({3, 3}));
-  auto edge = torch::autograd::Edge(send_function, 0);
-  o1.set_gradient_edge(edge);
-  auto o2 = torch::autograd::Variable(torch::ones({3, 3}));
-  edge = torch::autograd::Edge(send_function, 1);
-  o2.set_gradient_edge(edge);
-  auto loss = torch::add(o1, o2);
-
-  // Run backwards pass and verify gradients accumulated.
-  auto gradient = torch::autograd::Variable(torch::rand({3, 3}));
-  loss.backward(gradient, false, false);
-  ASSERT_TRUE(in1.grad().defined());
-  ASSERT_TRUE(in2.grad().defined());
-}
-
 TEST_F(DistAutogradTest, TestSendFunctionInvalidInputs) {
   auto options = at::TensorOptions().requires_grad(true);
   auto in1 = torch::ones({3, 3}, options);
@@ -64,12 +32,12 @@
   addSendRpcBackward(autogradContext, AutogradMetadata(1, 1), tensors);
   auto send_function = autogradContext.sendFunctions()[1];
 
-  // Build loss and attach it as input to send autograd function.
-  auto loss = torch::autograd::Variable(torch::ones({3, 3}));
-  loss.set_gradient_edge(torch::autograd::Edge(send_function, 1));
+  // This should fail since the SendRpcBackward function shouldn't receive any
+  // input grads.
+  EXPECT_THROW(send_function->apply({in1, in2}), c10::Error);
 
-  // This should fail since the SendRpcBackward function is looking for two
-  // inputs and as a result encounters an undefined grad.
-  EXPECT_THROW(
-      loss.backward(torch::autograd::Variable(), false, false), c10::Error);
+  // This should fail since the SendRpcBackward function encounters an undefined
+  // grad.
+  send_function->setGrads({in1, torch::autograd::Variable()});
+  EXPECT_THROW(send_function->apply({}), c10::Error);
 }
diff --git a/test/dist_autograd_test.py b/test/dist_autograd_test.py
index 56503f4..d56c866 100644
--- a/test/dist_autograd_test.py
+++ b/test/dist_autograd_test.py
@@ -8,6 +8,7 @@
 import torch.distributed.rpc as rpc
 from dist_utils import INIT_METHOD_TEMPLATE, dist_init
 
+import threading
 
 prev_rank_rpc_done = False
 prev_rank_context_id = 0
@@ -19,11 +20,49 @@
     prev_rank_rpc_done = True
     prev_rank_context_id = context_id
 
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+
+class SimulateBackwardError(Function):
+    @staticmethod
+    def forward(ctx, input):
+        return input
+
+    @staticmethod
+    @once_differentiable
+    def backward(ctx, input):
+        raise Exception('Simulate error on backward pass')
+
+from enum import Enum
+
+class ExecMode(Enum):
+    LOCAL = 1  # Run the operation locally.
+    REMOTE = 2  # Run the operation using RPC.
+
 
 @unittest.skipIf(
     not torch._six.PY3, "Pytorch distributed autograd package " "does not support python2"
 )
 class DistAutogradTest(object):
+
+    def _exec_func(self, exec_mode, method, *args):
+        if ExecMode.LOCAL == exec_mode:
+            if len(args) == 1 and isinstance(args[0], list):
+                return method(*args[0])
+            return method(*args)
+        else:
+            return rpc.rpc_sync('worker{}'.format(self._next_rank()), method,
+                                args=(args))
+
+    def _next_rank(self):
+        if hasattr(self, 'dst_rank'):
+            self.dst_rank = (self.dst_rank + 1) % self.world_size
+            if self.dst_rank == self.rank:
+                self._next_rank()
+        else:
+            self.dst_rank = (self.rank + 1) % self.world_size
+        return self.dst_rank
+
     @property
     def world_size(self):
         return 4
@@ -61,6 +100,14 @@
                 dist_autograd._retrieve_context(context_id)
 
     @dist_init
+    def test_nested_context(self):
+        with dist_autograd.context() as context_id:
+            # Nested contexts not supported.
+            with self.assertRaisesRegex(RuntimeError, "Already have an autograd context id for this thread"):
+                with dist_autograd.context() as context_id:
+                    pass
+
+    @dist_init
     def test_autograd_functions(self):
         dst_rank = (self.rank + 1) % self.world_size
         with dist_autograd.context() as context_id:
@@ -136,14 +183,13 @@
 
     @dist_init
     def test_rpc_complex_args(self):
-        dst_rank = (self.rank + 1) % self.world_size
         with dist_autograd.context() as context_id:
             num_tensors = 10
             tensors = []
             for i in range(num_tensors):
                 tensors.append(torch.ones(3, 3, requires_grad=(i % 2 == 0)))
             ret = rpc.rpc_sync(
-                "worker{}".format(dst_rank), torch.stack, args=(tensors,)
+                "worker{}".format(self._next_rank()), torch.stack, args=(tensors,)
             )
             self.assertEqual(torch.stack(tensors), ret)
 
@@ -160,3 +206,257 @@
                     self.assertEqual(tensors[i], next_funcs[i][0].variable)
                 else:
                     self.assertIsNone(next_funcs[i][0])
+
+    @dist_init
+    def test_error_in_context(self):
+        with dist_autograd.context() as context_id:
+            t1 = torch.rand(3, 3, requires_grad=True)
+            t2 = torch.rand(6, 6, requires_grad=True)
+
+
+            with self.assertRaises(RuntimeError):
+                # This should throw an error since matrix sizes don't match.
+                rpc.rpc_sync('worker{}'.format(self._next_rank()), torch.matmul,
+                             args=(t1, t2))
+
+    def _verify_backwards(self, exec_mode, tensors, context_id, local_grads, *args):
+        if exec_mode == ExecMode.REMOTE:
+            self._verify_backwards_remote(tensors, context_id, local_grads, *args)
+        else:
+            torch.autograd.backward(tensors)
+            return [arg.grad for arg in args]
+
+    def _verify_backwards_remote(self, tensors, context_id, local_grads, *args):
+        dist_autograd.backward(tensors)
+
+        # Verify grads were accumulated appropriately.
+        grads = dist_autograd.get_gradients(context_id)
+        nargs = len(args)
+        ngrads = 0
+        for i in range(0, nargs):
+            if local_grads[i] is not None:
+                self.assertIn(args[i], grads)
+                self.assertEqual(local_grads[i], grads[args[i]])
+                ngrads += 1
+            else:
+                self.assertNotIn(args[i], grads)
+
+        self.assertEqual(ngrads, len(grads))
+
+
+    @dist_init
+    def test_backward_simple(self):
+        # Run the same code locally and with dist autograd and verify the
+        # gradients are the same.
+        local_grads = None
+        t1 = torch.rand((3, 3), requires_grad=True)
+        t2 = torch.rand((3, 3), requires_grad=True)
+        for exec_mode in [ExecMode.LOCAL, ExecMode.REMOTE]:
+            with dist_autograd.context() as context_id:
+                ret = self._exec_func(exec_mode, torch.add, t1, t2)
+                loss = ret.sum()
+                local_grads = self._verify_backwards(exec_mode, [loss], context_id, local_grads, t1, t2)
+
+    @dist_init
+    def test_backward_multiple_round_trips(self):
+        local_grads = None
+        t1 = torch.rand((3, 3), requires_grad=True)
+        t2 = torch.rand((3, 3))
+        t3 = torch.rand((3, 3), requires_grad=True)
+        t4 = torch.rand((3, 3))
+        t5 = torch.rand((3, 3), requires_grad=True)
+
+        for exec_mode in [ExecMode.LOCAL, ExecMode.REMOTE]:
+            with dist_autograd.context() as context_id:
+                # Multiple RPCs between different nodes.
+                val = self._exec_func(exec_mode, torch.add, t1, t2)
+                val = self._exec_func(exec_mode, torch.mul, t3, val)
+                s1 = self._exec_func(exec_mode, torch.stack, (t4, val))
+                s2 = self._exec_func(exec_mode, torch.stack, (t5, val))
+                val = self._exec_func(exec_mode, torch.bmm, s1, s2)
+                val = self._exec_func(exec_mode, torch.matmul, val, val)
+                loss = val.sum()
+
+                local_grads = self._verify_backwards(exec_mode, [loss], context_id, local_grads, t1, t2, t3, t4, t5)
+
+    @dist_init
+    def test_backward_different_tensor_dims(self):
+        local_grads = None
+        t1 = torch.rand((4, 6), requires_grad=True)
+        t2 = torch.rand((6, 5))
+        t3 = torch.rand((5, 7), requires_grad=True)
+        t4 = torch.rand((7, 9))
+
+        for exec_mode in [ExecMode.LOCAL, ExecMode.REMOTE]:
+            with dist_autograd.context() as context_id:
+                val = self._exec_func(exec_mode, torch.matmul, t1, t2)
+                val = self._exec_func(exec_mode, torch.chain_matmul, [val, t3, t4])
+                loss = val.sum()
+
+                local_grads = self._verify_backwards(exec_mode, [loss], context_id, local_grads, t1, t2, t2, t3, t4)
+
+    @dist_init
+    def test_backward_unused_tensors(self):
+        local_grads = None
+        t1 = torch.rand((3, 3), requires_grad=True)
+        t2 = torch.rand((3, 3), requires_grad=True)
+        t3 = torch.rand((3, 3), requires_grad=True)
+        for exec_mode in [ExecMode.LOCAL, ExecMode.REMOTE]:
+            with dist_autograd.context() as context_id:
+                s = self._exec_func(exec_mode, torch.stack, (t1, t2, t3))
+                val = self._exec_func(exec_mode, torch.matmul, torch.narrow(s, 0, 0, 1), torch.narrow(s, 0, 2, 1))
+
+                loss = val.sum()
+                local_grads = self._verify_backwards(exec_mode, [loss], context_id, local_grads, t1, t2, t3)
+
+    @dist_init
+    def test_backward_multiple_output_tensors(self):
+        local_grads = None
+        t = torch.rand((10, 2), requires_grad=True)
+        for exec_mode in [ExecMode.LOCAL, ExecMode.REMOTE]:
+            with dist_autograd.context() as context_id:
+                tensor_list = self._exec_func(exec_mode, torch.split, t, 2)
+                t1 = tensor_list[0]
+                t2 = tensor_list[2]
+                t3 = tensor_list[4]
+
+                val = self._exec_func(exec_mode, torch.chain_matmul, [t1, t2, t3])
+
+                loss = val.sum()
+                local_grads = self._verify_backwards(exec_mode, [loss], context_id, local_grads, t)
+
+    def _run_test_backward_unused_send_function_in_thread(self):
+        with dist_autograd.context() as context_id:
+            t1 = torch.rand((3, 3), requires_grad=True)
+            t2 = torch.rand((3, 3), requires_grad=True)
+
+            # We don't use the result of an RPC function, as a result the
+            # backward pass would hang in the "FAST" mode.
+            res = rpc.rpc_sync('worker{}'.format(self._next_rank()), torch.add,
+                               args=(t1, t2))
+
+            val = torch.mul(t1, t2)
+
+            # Run backward, this would hang forever.
+            dist_autograd.backward([val.sum()])
+
+
+    @dist_init
+    def test_backward_unused_send_function(self):
+        # Run the test in a thread which would never finish.
+        t = threading.Thread(target=self._run_test_backward_unused_send_function_in_thread)
+        t.daemon = True
+        t.start()
+        t.join(10)  # Wait for 10s.
+
+        # Verify thread is still alive (indicating backward hasn't completed yet).
+        self.assertTrue(t.is_alive())
+
+    @dist_init
+    def test_backward_autograd_engine_error(self):
+        with dist_autograd.context() as context_id:
+            t1 = torch.rand((3, 3), requires_grad=True)
+            t2 = torch.rand((3, 3), requires_grad=True)
+            t3 = SimulateBackwardError.apply(t1)
+
+            # Run multiple round trips across different nodes and verify the
+            # original node receives an error thrown on a node deep in the chain.
+            val = rpc.rpc_sync('worker{}'.format(self._next_rank()), torch.add,
+                               args=(t2, t3))
+            val = rpc.rpc_sync('worker{}'.format(self._next_rank()), torch.mul,
+                               args=(val, t2))
+            val = rpc.rpc_sync('worker{}'.format(self._next_rank()), torch.matmul,
+                               args=(val, t2))
+            val = rpc.rpc_sync('worker{}'.format(self._next_rank()), torch.div,
+                               args=(val, t2))
+
+            with self.assertRaises(RuntimeError):
+                # Run backwards, and validate we receive an error.
+                dist_autograd.backward([val.sum()])
+
+    @dist_init
+    @unittest.skip("Skipping this test temporarily since ProcessGroupAgent does not report errors on node failures")
+    def test_backward_node_failure(self):
+        with dist_autograd.context() as context_id:
+            t1 = torch.rand((3, 3), requires_grad=True)
+            t2 = torch.rand((3, 3), requires_grad=True)
+
+            res = rpc.rpc_sync('worker{}'.format(self._next_rank()), torch.add,
+                               args=(t1, t2))
+
+            if self.rank == 0:
+                # Wait a bit for all other nodes to die.
+                time.sleep(3)
+                with self.assertRaises(RuntimeError):
+                    # Run backwards, and validate we receive an error since all
+                    # other nodes are dead.
+                    dist_autograd.backward([res.sum()])
+            else:
+                # Kill all other nodes.
+                sys.exit(0)
+
+    @dist_init
+    def test_backward_without_context(self):
+        t1 = torch.rand((3, 3), requires_grad=True)
+        t2 = torch.rand((3, 3), requires_grad=True)
+
+        with self.assertRaisesRegex(RuntimeError, "Current thread doesn't have a valid autograd context"):
+            res = rpc.rpc_sync('worker{}'.format(self._next_rank()), torch.add,
+                               args=(t1, t2))
+            dist_autograd.backward([res.sum()])
+
+    @dist_init
+    def test_backward_without_rpc(self):
+        dst_rank = self.rank
+        with dist_autograd.context() as context_id:
+            t1 = torch.rand((3, 3), requires_grad=True)
+            t2 = torch.rand((3, 3), requires_grad=True)
+            t3 = torch.add(t1, t2)
+
+            dist_autograd.backward([t3.sum()])
+            grads = dist_autograd.get_gradients(context_id)
+            self.assertEqual(2, len(grads))
+            self.assertIn(t1, grads)
+            self.assertIn(t2, grads)
+            self.assertEqual(torch.ones(3, 3), grads[t1])
+            self.assertEqual(torch.ones(3, 3), grads[t2])
+
+    @dist_init
+    def test_backward_invalid_args(self):
+        with dist_autograd.context() as context_id:
+
+            with self.assertRaisesRegex(TypeError, "incompatible function arguments"):
+                dist_autograd.backward(None)
+
+            with self.assertRaisesRegex(RuntimeError, "No tensors provided for gradient computation"):
+                dist_autograd.backward([])
+
+            with self.assertRaisesRegex(RuntimeError, "requires_grad not set on"):
+                t = torch.rand(3, 3)
+                dist_autograd.backward([t])
+
+            with self.assertRaisesRegex(RuntimeError, "is not a scalar, all roots need to be scalar"):
+                t = torch.rand(3, 3, requires_grad=True)
+                dist_autograd.backward([t])
+
+            with self.assertRaisesRegex(RuntimeError, "does not have a valid gradient function"):
+                t = torch.rand(1, requires_grad=True)
+                dist_autograd.backward([t])
+
+    @dist_init
+    def test_backward_multiple_roots(self):
+        local_grads = None
+        t1 = torch.rand((3, 3), requires_grad=True)
+        t2 = torch.rand((3, 3), requires_grad=True)
+        for exec_mode in [ExecMode.LOCAL, ExecMode.REMOTE]:
+            with dist_autograd.context() as context_id:
+                r1 = self._exec_func(exec_mode, torch.add, t1, t2).sum()
+                r2 = self._exec_func(exec_mode, torch.mul, t1, t2).sum()
+                r3 = self._exec_func(exec_mode, torch.cos, t1).sum()
+                r4 = self._exec_func(exec_mode, torch.div, t1, t2).sum()
+
+                local_grads = self._verify_backwards(exec_mode, [r1, r2, r3, r4], context_id, local_grads, t1, t2)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/test/dist_utils.py b/test/dist_utils.py
index a760e48..fd95eb1 100644
--- a/test/dist_utils.py
+++ b/test/dist_utils.py
@@ -38,11 +38,13 @@
     def wrapper(self, *arg, **kwargs):
         self.worker_id = self.rank
         dist.init_process_group(backend="gloo", init_method=self.init_method)
+        # Use enough 'num_send_recv_threads' until we fix https://github.com/pytorch/pytorch/issues/26359
         rpc.init_model_parallel(
             self_name="worker%d" % self.rank,
             backend=TEST_CONFIG.rpc_backend,
             self_rank=self.rank,
             init_method=self.init_method,
+            num_send_recv_threads=16
         )
         test_method(self, *arg, **kwargs)
         rpc.join_rpc()
diff --git a/test/rpc_test.py b/test/rpc_test.py
index d985bb6..6f6fe19 100644
--- a/test/rpc_test.py
+++ b/test/rpc_test.py
@@ -226,7 +226,7 @@
             rpc.rpc_sync(self_worker_name, torch.add, args=(torch.ones(2, 2), 1))
 
     @mock.patch.object(torch.distributed.autograd, "_init")
-    @mock.patch.object(torch.distributed.rpc.api, "_init_rref_context")
+    @mock.patch.object(torch.distributed.rpc.api, "_init_rpc_agent")
     def test_register_rpc_backend_and_init_rpc_backend(
         self, mock_init_rref_context, mock_dist_autograd_init
     ):
diff --git a/tools/build_variables.py b/tools/build_variables.py
index c5d47b3..b2b2898 100644
--- a/tools/build_variables.py
+++ b/tools/build_variables.py
@@ -52,15 +52,20 @@
     "torch/csrc/distributed/autograd/utils.cpp",
     "torch/csrc/distributed/autograd/context/dist_autograd_container.cpp",
     "torch/csrc/distributed/autograd/context/dist_autograd_context.cpp",
+    "torch/csrc/distributed/autograd/engine/dist_engine.cpp",
     "torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp",
     "torch/csrc/distributed/autograd/functions/sendrpc_backward.cpp",
+    "torch/csrc/distributed/autograd/rpc_messages/autograd_metadata.cpp",
+    "torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.cpp",
+    "torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_resp.cpp",
+    "torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.cpp",
     "torch/csrc/distributed/rpc/future_message.cpp",
     "torch/csrc/distributed/rpc/message.cpp",
     "torch/csrc/distributed/rpc/python_remote_call.cpp",
     "torch/csrc/distributed/rpc/python_udf_call.cpp",
     "torch/csrc/distributed/rpc/python_udf_resp.cpp",
+    "torch/csrc/distributed/rpc/rpc_agent.cpp",
     "torch/csrc/distributed/rpc/request_callback.cpp",
-    "torch/csrc/distributed/rpc/rpc_with_autograd.cpp",
     "torch/csrc/distributed/rpc/rref_proto.cpp",
     "torch/csrc/distributed/rpc/script_call.cpp",
     "torch/csrc/distributed/rpc/script_remote_call.cpp",
@@ -275,7 +280,6 @@
         "torch/csrc/distributed/rpc/python_functions.cpp",
         "torch/csrc/distributed/rpc/python_rpc_handler.cpp",
         "torch/csrc/distributed/rpc/request_callback_impl.cpp",
-        "torch/csrc/distributed/rpc/rpc_agent.cpp",
         "torch/csrc/distributed/rpc/rref.cpp",
         "torch/csrc/distributed/rpc/rref_context.cpp",
         "torch/csrc/jit/init.cpp",
diff --git a/torch/CMakeLists.txt b/torch/CMakeLists.txt
index 7372180..f85e06f 100644
--- a/torch/CMakeLists.txt
+++ b/torch/CMakeLists.txt
@@ -233,7 +233,6 @@
         ${TORCH_SRC_DIR}/csrc/distributed/rpc/python_functions.cpp
         ${TORCH_SRC_DIR}/csrc/distributed/rpc/python_rpc_handler.cpp
         ${TORCH_SRC_DIR}/csrc/distributed/rpc/request_callback_impl.cpp
-        ${TORCH_SRC_DIR}/csrc/distributed/rpc/rpc_agent.cpp
         ${TORCH_SRC_DIR}/csrc/distributed/rpc/rref_context.cpp
         ${TORCH_SRC_DIR}/csrc/distributed/rpc/rref.cpp
         )
diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp
index 2a43c91..bfae098 100644
--- a/torch/csrc/autograd/engine.cpp
+++ b/torch/csrc/autograd/engine.cpp
@@ -10,7 +10,6 @@
 #include <ATen/DeviceGuard.h>
 #include <ATen/ExpandUtils.h>
 #include <ATen/Parallel.h>
-#include <ATen/ThreadLocalDebugInfo.h>
 #include <c10/util/Exception.h>
 #include <c10/core/Stream.h>
 #include <c10/core/Event.h>
@@ -36,9 +35,6 @@
 
 namespace torch { namespace autograd {
 
-// NB: -1 indicates the CPU worker!
-static constexpr int NO_DEVICE = -2;
-
 // Threads spawned by the engine are assigned a constant 'worker_device'
 // specifying what device they process work for.  This variable is initialized
 // at thread creation time and is constant afterwards.  This is used when
@@ -62,26 +58,6 @@
 // Total nested reentrant backwards calls over all threads for workder_device
 static thread_local int total_depth = 0;
 
-struct NodeTask {
-  GraphTask* base_;
-  std::shared_ptr<Node> fn_;
-  // This buffer serves as an implicit "addition" node for all of the
-  // gradients flowing here.  Once all the dependencies are finished, we
-  // use the contents of this buffer to run the function.
-  InputBuffer inputs_;
-  // When worker receives a task with isShutdownTask = true, it will immediately
-  // exit. The engine sends a shutdown task to every queue upon its destruction.
-  bool isShutdownTask_;
-
-  int getReentrantDepth() const;
-
-  NodeTask(GraphTask* base, std::shared_ptr<Node> fn, InputBuffer inputs, bool isShutdownTask = false)
-    : base_(base)
-    , fn_(std::move(fn))
-    , inputs_(std::move(inputs))
-    , isShutdownTask_(isShutdownTask) {}
-};
-
 // Returns true when t2 should be (weakly) BEFORE t1 in the queue.
 // Shutdown tasks are first and then empty NodeTask are next.
 struct CompareNodeTaskTime {
@@ -107,7 +83,11 @@
   // To protect read and writes to heap_
   std::mutex mutex_;
 
-  void push(NodeTask item);
+  // incrementOutstandingTasks indicates whether or not we should increment
+  // 'outstanding_tasks_' for the associated GraphTask. This should almost
+  // always be true; see the doc for 'enqueue_blocked_task_on_cpu' for when we
+  // might set this to false.
+  void push(NodeTask item, bool incrementOutstandingTasks = true);
   void pushShutdownTask();
   NodeTask pop();
 };
@@ -171,83 +151,17 @@
 // the leaf streams with the default streams is sufficient to implement
 // the historic behavior.
 
-// GraphTask holds metadata needed for a single execution of backward()
-struct GraphTask {
-  std::exception_ptr exception_;
-  // Indicates if an error occurred while executing any task.  When this is
-  // true, it signals all threads to stop executing.
-  std::atomic_bool has_error_;
-  std::atomic<uint64_t> outstanding_tasks_;
-  // It is safe to read grad_mode_ and keep_graph_ without synchronization
-  bool keep_graph_;
-  bool grad_mode_;
-
-  // To protect reads/writes to no_ready_, dependencies_ , captured_vars_ and
-  // exception_
-  std::mutex mutex_;
-  // Notified when a task finishes executing.  Check outstanding_tasks_ to see
-  // if all tasks are done.
-  std::condition_variable not_done_;
-  std::unordered_map<Node*, InputBuffer> not_ready_;
-  std::unordered_map<Node*, int> dependencies_;
-
-  struct ExecInfo {
-    struct Capture {
-      Capture(int input_idx, int output_idx) : input_idx_(input_idx), output_idx_(output_idx) {}
-      int input_idx_; // within Node inputs
-      int output_idx_; // within the output vector of a GraphTask
-    };
-
-    bool should_execute() const {
-      return needed_ || captures_;
-    }
-
-    bool needed_ = false;
-    std::unique_ptr<std::vector<Capture>> captures_;
-  };
-  // Exec info has a bit complicated semantics. If it's empty, it means the task is
-  // run in a "default" mode, which means that all next_edges we encounter should
-  // get executed. If it's not empty, only functions that have an entry and this entry
-  // has needed == True should be executed.
-  // exec_info_.empty() means it's .backward(), otherwise it's .grad().
-  // exec_info_ is safe to read without synchronization
-  std::unordered_map<Node*, ExecInfo> exec_info_;
-  std::vector<Variable> captured_vars_;
-  std::shared_ptr<at::ThreadLocalDebugInfoBase> debug_info_ =
-      at::getThreadLocalDebugInfo();
-  std::unordered_set<c10::Stream> leaf_streams;
-
-  void init_to_execute(Node& graph_root, const edge_list& outputs);
-
-  // The value of worker_device in the thread that created this task.
-  // See Note [Reentrant backwards]
-  // Safe to read owner_ and reentrant_depth_ without synchronizaton
-  int owner_;
-  // The number of parent graph tasks for this graph task
-  const int reentrant_depth_;
-
-  bool can_checkpoint() {
-    return exec_info_.empty();
-  }
-
-  GraphTask(bool keep_graph, bool grad_mode, int reentrant_depth)
-    : has_error_(false)
-    , outstanding_tasks_(0)
-    , keep_graph_(keep_graph)
-    , grad_mode_(grad_mode)
-    , owner_(NO_DEVICE)
-    , reentrant_depth_(reentrant_depth) {}
-};
-
 int NodeTask::getReentrantDepth() const {
   return base_->reentrant_depth_;
 }
 
-auto ReadyQueue::push(NodeTask item) -> void {
+auto ReadyQueue::push(NodeTask item, bool incrementOutstandingTasks) -> void {
   {
     // Lock mutex for writing to heap_
     std::lock_guard<std::mutex> lock(mutex_);
-    ++item.base_->outstanding_tasks_;
+    if (incrementOutstandingTasks) {
+      ++item.base_->outstanding_tasks_;
+    }
     heap_.push(std::move(item));
   }
   not_empty_.notify_one();
@@ -458,8 +372,10 @@
       expected == actual.toBackend(toDense(actual.backend())));
 }
 
-template<typename F>
-static void validate_outputs(const edge_list& edges, variable_list& grads, const F& format_error) {
+void validate_outputs(
+    const edge_list& edges,
+    variable_list& grads,
+    const std::function<std::string(const std::string&)>& format_error) {
   if (grads.size() != edges.size()) {
     std::stringstream ss;
     ss << "invalid number of gradients - expected ";
@@ -622,6 +538,7 @@
     bool is_ready = false;
     auto& dependencies = task.base_->dependencies_;
     auto it = dependencies.find(next.function.get());
+
     if (it == dependencies.end()) {
       auto name = next.function->name();
       throw std::runtime_error(std::string("dependency not found for ") + name);
@@ -717,8 +634,6 @@
                      bool keep_graph,
                      bool create_graph,
                      const edge_list& outputs) -> variable_list {
-  std::call_once(start_threads_flag_, &Engine::start_threads, this);
-
   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
   validate_outputs(roots, const_cast<variable_list&>(inputs), [](const std::string& msg) {
     return msg;
@@ -729,15 +644,30 @@
   ClearCallbacks _cb_guard(final_callbacks_, post_callbacks_lock_);
 
   GraphTask graph_task(keep_graph, create_graph, worker_device == NO_DEVICE ? 0 : total_depth+1);
-  // Lock mutex while GraphTask is being set up
-  std::unique_lock<std::mutex> lock(graph_task.mutex_);
 
   // Now compute the dependencies for all executable functions and queue the root
   auto graph_root = std::make_shared<GraphRoot>(roots, inputs);
   compute_dependencies(graph_root.get(), graph_task);
+
   if (!outputs.empty()) {
     graph_task.init_to_execute(*graph_root, outputs);
   }
+  return execute_with_graph_task(graph_task, graph_root);
+}
+
+void Engine::enqueue_blocked_task_on_cpu(NodeTask task) {
+  std::call_once(start_threads_flag_, &Engine::start_threads, this);
+  ready_queue(at::kCPU).push(
+      std::move(task), /* incrementOutstandingTasks */ false);
+}
+
+variable_list Engine::execute_with_graph_task(
+    GraphTask& graph_task,
+    std::shared_ptr<Node> graph_root) {
+  std::call_once(start_threads_flag_, &Engine::start_threads, this);
+  // Lock mutex for GraphTask.
+  std::unique_lock<std::mutex> lock(graph_task.mutex_);
+
   ready_queue(at::kCPU).push(NodeTask(&graph_task, std::move(graph_root), InputBuffer(0)));
 
   // Not a worker
diff --git a/torch/csrc/autograd/engine.h b/torch/csrc/autograd/engine.h
index 8096832..b04132e 100644
--- a/torch/csrc/autograd/engine.h
+++ b/torch/csrc/autograd/engine.h
@@ -3,10 +3,12 @@
 // Engine implements backpropagation from output variables and their gradients
 // to "root" variables (variables created by the user with requires_grad=True).
 
+#include <ATen/ThreadLocalDebugInfo.h>
 #include <torch/csrc/WindowsTorchApiMacro.h>
-#include <torch/csrc/autograd/function.h>
-#include <torch/csrc/autograd/input_buffer.h>
 #include <torch/csrc/autograd/anomaly_mode.h>
+#include <torch/csrc/autograd/function.h>
+#include <torch/csrc/autograd/functions/basic_ops.h>
+#include <torch/csrc/autograd/input_buffer.h>
 
 #include <deque>
 #include <exception>
@@ -20,11 +22,111 @@
 
 namespace torch { namespace autograd {
 struct ReadyQueue;
-struct NodeTask;
-struct GraphTask;
 }} // namespace torch::autograd
 
 namespace torch { namespace autograd {
+
+void validate_outputs(
+    const edge_list& edges,
+    variable_list& grads,
+    const std::function<std::string(const std::string&)>& format_error);
+
+// NB: -1 indicates the CPU worker!
+static constexpr int NO_DEVICE = -2;
+
+// GraphTask holds metadata needed for a single execution of backward()
+struct GraphTask {
+  std::exception_ptr exception_;
+  // Indicates if an error occurred while executing any task.  When this is
+  // true, it signals all threads to stop executing.
+  std::atomic_bool has_error_;
+  std::atomic<uint64_t> outstanding_tasks_;
+  // It is safe to read grad_mode_ and keep_graph_ without synchronization
+  bool keep_graph_;
+  bool grad_mode_;
+
+  // To protect reads/writes to not_ready_, dependencies_, captured_vars_ and
+  // exception_
+  std::mutex mutex_;
+  // Notified when a task finishes executing.  Check outstanding_tasks_ to see
+  // if all tasks are done.
+  std::condition_variable not_done_;
+  std::unordered_map<Node*, InputBuffer> not_ready_;
+  std::unordered_map<Node*, int> dependencies_;
+
+  struct ExecInfo {
+    struct Capture {
+      Capture(int input_idx, int output_idx)
+          : input_idx_(input_idx), output_idx_(output_idx) {}
+      int input_idx_; // within Node inputs
+      int output_idx_; // within the output vector of a GraphTask
+    };
+
+    bool should_execute() const {
+      return needed_ || captures_;
+    }
+
+    bool needed_ = false;
+    std::unique_ptr<std::vector<Capture>> captures_;
+  };
+  // Exec info has a bit complicated semantics. If it's empty, it means the task
+  // is run in a "default" mode, which means that all next_edges we encounter
+  // should get executed. If it's not empty, only functions that have an entry
+  // and this entry has needed == True should be executed. exec_info_.empty()
+  // means it's .backward(), otherwise it's .grad(). exec_info_ is safe to read
+  // without synchronization
+  std::unordered_map<Node*, ExecInfo> exec_info_;
+  std::vector<Variable> captured_vars_;
+  std::shared_ptr<at::ThreadLocalDebugInfoBase> debug_info_ =
+      at::getThreadLocalDebugInfo();
+  std::unordered_set<c10::Stream> leaf_streams;
+
+  void init_to_execute(Node& graph_root, const edge_list& outputs);
+
+  // The value of worker_device in the thread that created this task.
+  // See Note [Reentrant backwards]
+  // Safe to read owner_ and reentrant_depth_ without synchronization
+  int owner_;
+  // The number of parent graph tasks for this graph task
+  const int reentrant_depth_;
+
+  bool can_checkpoint() {
+    return exec_info_.empty();
+  }
+
+  GraphTask(bool keep_graph, bool grad_mode, int reentrant_depth)
+      : has_error_(false),
+        outstanding_tasks_(0),
+        keep_graph_(keep_graph),
+        grad_mode_(grad_mode),
+        owner_(NO_DEVICE),
+        reentrant_depth_(reentrant_depth) {}
+};
+
+struct NodeTask {
+  GraphTask* base_;
+  std::shared_ptr<Node> fn_;
+  // This buffer serves as an implicit "addition" node for all of the
+  // gradients flowing here.  Once all the dependencies are finished, we
+  // use the contents of this buffer to run the function.
+  InputBuffer inputs_;
+  // When worker receives a task with isShutdownTask = true, it will immediately
+  // exit. The engine sends a shutdown task to every queue upon its destruction.
+  bool isShutdownTask_;
+
+  int getReentrantDepth() const;
+
+  NodeTask(
+      GraphTask* base,
+      std::shared_ptr<Node> fn,
+      InputBuffer inputs,
+      bool isShutdownTask = false)
+      : base_(base),
+        fn_(std::move(fn)),
+        inputs_(std::move(inputs)),
+        isShutdownTask_(isShutdownTask) {}
+};
+
 // A single instance of this struct should be created through the whole process lifetime.
 // The worker thread creation logic and Engine's destructor rely on this.
 struct TORCH_API Engine {
@@ -45,6 +147,28 @@
       bool keep_graph,
       bool create_graph,
       const edge_list& outputs = {});
+
+  // Given a pre-populated GraphTask and GraphRoot, computes the backward pass
+  // for the graph. This API should only be used by internal autograd specific
+  // machinery and shouldn't be exposed to users in any way.
+  variable_list execute_with_graph_task(
+      GraphTask& graph_task,
+      std::shared_ptr<Node> graph_root);
+
+  // Enqueues a blocked task for execution on the CPU thread. A blocked task is
+  // basically a task that isn't triggered automatically to be
+  // 'ready to execute' by the autograd engine. This task needs to be unblocked
+  // for execution via an external mechanism. This method assumes that
+  // the appropriate GraphTask has already been initialized.
+  // Another important part is that this does not increment 'outstanding_tasks_'
+  // in the appropriate GraphTask. It is assumed we've already done this
+  // beforehand for this task (to ensure we block for its execution). This is useful
+  // in the distributed autograd case where we need to increment
+  // 'outstanding_tasks_' first to indicate the local autograd engine needs to
+  // wait for this task, but the task might actually be received later over the
+  // network for execution.
+  void enqueue_blocked_task_on_cpu(NodeTask task);
+
   virtual std::unique_ptr<AnomalyMetadata> make_anomaly_metadata() {
     return nullptr;
   }
diff --git a/torch/csrc/distributed/autograd/context/dist_autograd_container.cpp b/torch/csrc/distributed/autograd/context/dist_autograd_container.cpp
index 3302e1d..dfa7bae6 100644
--- a/torch/csrc/distributed/autograd/context/dist_autograd_container.cpp
+++ b/torch/csrc/distributed/autograd/context/dist_autograd_container.cpp
@@ -9,8 +9,10 @@
 constexpr int64_t kAutoIncrementMask = (1LL << kAutoIncrementBits) - 1;
 constexpr int kMaxWorkerId = 65535;
 
+constexpr int64_t kInvalidContextId = -1;
+
 // Each thread has a single autograd_context_id valid at any point in time.
-static thread_local int64_t current_context_id_ = -1;
+static thread_local int64_t current_context_id_ = kInvalidContextId;
 
 // Lock to ensure DistAutogradContainer is initialized only once.
 static std::mutex dist_container_init_lock_;
@@ -84,6 +86,10 @@
 }
 
 const DistAutogradContext& DistAutogradContainer::newContext() {
+  TORCH_CHECK(
+      current_context_id_ == kInvalidContextId,
+      "Already have an autograd context id for this thread.");
+
   std::lock_guard<std::mutex> guard(autograd_context_lock_);
   // Check for overflow into workerId_ section.
   TORCH_INTERNAL_ASSERT(next_context_id_ < max_id_);
@@ -100,7 +106,7 @@
 }
 
 bool DistAutogradContainer::hasValidContext() const {
-  return current_context_id_ != -1;
+  return current_context_id_ != kInvalidContextId;
 }
 
 DistAutogradContext& DistAutogradContainer::currentContext() {
@@ -128,7 +134,7 @@
 
   if (current_context_id_ == context_id) {
     // Reset the thread_local current context id, since it is no longer valid.
-    current_context_id_ = -1;
+    current_context_id_ = kInvalidContextId;
   }
 }
 
diff --git a/torch/csrc/distributed/autograd/context/dist_autograd_context.cpp b/torch/csrc/distributed/autograd/context/dist_autograd_context.cpp
index 0323ad8..2071ef5 100644
--- a/torch/csrc/distributed/autograd/context/dist_autograd_context.cpp
+++ b/torch/csrc/distributed/autograd/context/dist_autograd_context.cpp
@@ -5,11 +5,11 @@
 namespace distributed {
 namespace autograd {
 
-DistAutogradContext::DistAutogradContext(int64_t context_id)
-    : context_id_(context_id) {}
+DistAutogradContext::DistAutogradContext(int64_t contextId)
+    : contextId_(contextId) {}
 
-int64_t DistAutogradContext::context_id() const {
-  return context_id_;
+int64_t DistAutogradContext::contextId() const {
+  return contextId_;
 }
 
 void DistAutogradContext::addSendFunction(
@@ -48,6 +48,73 @@
   return recvAutogradFunctions_;
 }
 
+void DistAutogradContext::accumulateGrad(
+    const torch::autograd::Variable& variable,
+    const torch::Tensor& grad) {
+  TORCH_INTERNAL_ASSERT(grad.defined());
+  TORCH_INTERNAL_ASSERT(variable.requires_grad());
+
+  std::lock_guard<std::mutex> guard(lock_);
+  auto it = accumulatedGrads_.find(variable);
+  if (it != accumulatedGrads_.end()) {
+    // Accumulate multiple grads on the same variable.
+    it->value().add_(grad);
+  } else {
+    // First grad for this variable.
+    accumulatedGrads_.insert(variable, grad);
+  }
+}
+
+std::shared_ptr<torch::autograd::GraphTask> DistAutogradContext::
+    retrieveGraphTask() {
+  std::lock_guard<std::mutex> guard(lock_);
+  TORCH_INTERNAL_ASSERT(graphTask_);
+  return graphTask_;
+}
+
+void DistAutogradContext::setGraphTask(
+    std::shared_ptr<torch::autograd::GraphTask> graphTask) {
+  std::lock_guard<std::mutex> guard(lock_);
+  TORCH_INTERNAL_ASSERT(
+      !graphTask_,
+      "Cannot set GraphTask multiple times for the same autograd context");
+  graphTask_ = std::move(graphTask);
+}
+
+void DistAutogradContext::addOutstandingRpc(
+    const std::shared_ptr<rpc::FutureMessage>& futureMessage) {
+  std::lock_guard<std::mutex> guard(lock_);
+  outStandingRpcs_.push_back(futureMessage);
+}
+
+void DistAutogradContext::clearAndWaitForOutstandingRpcs() {
+  // Copy futures under lock, but wait for them outside the lock.
+  std::unique_lock<std::mutex> lock(lock_);
+  auto outStandingRpcs = std::move(outStandingRpcs_);
+  lock.unlock();
+
+  for (const auto& outStandingRpc : outStandingRpcs) {
+    outStandingRpc->wait();
+  }
+}
+
+std::shared_ptr<SendRpcBackward> DistAutogradContext::retrieveSendFunction(
+    int64_t autograd_message_id) {
+  std::lock_guard<std::mutex> guard(lock_);
+  auto it = sendAutogradFunctions_.find(autograd_message_id);
+  TORCH_CHECK(
+      it != sendAutogradFunctions_.end(),
+      "Could not find send function for autograd message id: ",
+      autograd_message_id);
+  return it->second;
+}
+
+const c10::Dict<torch::Tensor, torch::Tensor> DistAutogradContext::
+    getGradients() const {
+  std::lock_guard<std::mutex> guard(lock_);
+  return accumulatedGrads_;
+}
+
 } // namespace autograd
 } // namespace distributed
 } // namespace torch
diff --git a/torch/csrc/distributed/autograd/context/dist_autograd_context.h b/torch/csrc/distributed/autograd/context/dist_autograd_context.h
index 1ebd91e..f5bc4dd 100644
--- a/torch/csrc/distributed/autograd/context/dist_autograd_context.h
+++ b/torch/csrc/distributed/autograd/context/dist_autograd_context.h
@@ -1,21 +1,26 @@
 #pragma once
 
+#include <ATen/core/Dict.h>
+#include <torch/csrc/autograd/engine.h>
 #include <torch/csrc/distributed/autograd/functions/recvrpc_backward.h>
 #include <torch/csrc/distributed/autograd/functions/sendrpc_backward.h>
+#include <torch/csrc/distributed/rpc/future_message.h>
 #include <cstdint>
 
 namespace torch {
 namespace distributed {
 namespace autograd {
 
+class RecvRpcBackward;
+
 // DistAutogradContext which stores information for a single distributed
 // autograd pass on a worker.
 class TORCH_API DistAutogradContext {
  public:
-  explicit DistAutogradContext(int64_t context_id);
+  explicit DistAutogradContext(int64_t contextId);
 
   // Retrieves the autograd context id for this context.
-  int64_t context_id() const;
+  int64_t contextId() const;
 
   // Records a 'send' autograd function for this context with the provided
   // message id.
@@ -29,19 +34,51 @@
       std::shared_ptr<RecvRpcBackward>& func,
       int64_t autograd_message_id);
 
+  // Given an autograd_message_id, retrieve the appropriate send function.
+  std::shared_ptr<SendRpcBackward> retrieveSendFunction(
+      int64_t autograd_message_id);
+
+  // Return all send functions for this context.
   std::unordered_map<int64_t, std::shared_ptr<SendRpcBackward>> sendFunctions()
       const;
 
+  // Return all recv functions for this context.
   std::unordered_map<int64_t, std::shared_ptr<RecvRpcBackward>> recvFunctions()
       const;
 
+  // Adds a future message recording an outstanding RPC.
+  void addOutstandingRpc(
+      const std::shared_ptr<rpc::FutureMessage>& futureMessage);
+
+  // Returns all gradients.
+  const c10::Dict<torch::Tensor, torch::Tensor> getGradients() const;
+
   DistAutogradContext(const DistAutogradContext&) = delete;
   DistAutogradContext& operator=(const DistAutogradContext&) = delete;
   DistAutogradContext(DistAutogradContext&&) = delete;
   DistAutogradContext& operator=(DistAutogradContext&&) = delete;
 
  private:
-  const int64_t context_id_;
+  friend class DistEngine;
+
+  // Record that we would like to accumulate the provided gradient on the given
+  // variable.
+  void accumulateGrad(
+      const torch::autograd::Variable& variable,
+      const torch::Tensor& grad);
+
+  // Retrieve the GraphTask.
+  std::shared_ptr<torch::autograd::GraphTask> retrieveGraphTask();
+
+  // Set the appropriate graph task for the backward pass. Can be called only
+  // once.
+  void setGraphTask(std::shared_ptr<torch::autograd::GraphTask> graphTask);
+
+  // Waits for all outstanding RPCs for this context to finish and clears all
+  // outstanding rpcs held in this context. This should be called only once.
+  void clearAndWaitForOutstandingRpcs();
+
+  const int64_t contextId_;
 
   // Map from autograd_message_id to appropriate 'send' autograd function.
   std::unordered_map<int64_t, std::shared_ptr<SendRpcBackward>>
@@ -51,6 +88,19 @@
   std::unordered_map<int64_t, std::shared_ptr<RecvRpcBackward>>
       recvAutogradFunctions_;
 
+  // Gradients accumulated in this context so far. The key is the variable on
+  // which the gradient needs to be accumulated and the value is the gradient
+  // that needs to be accumulated on that variable.
+  c10::Dict<torch::Tensor, torch::Tensor> accumulatedGrads_;
+
+  // The autograd GraphTask for the backward pass on this node for this context.
+  std::shared_ptr<torch::autograd::GraphTask> graphTask_;
+
+  // List of futures for RPCs initiated by this node to propagate gradients to
+  // other nodes. The distributed autograd engine on this node can return
+  // successfully only if all these futures are done and are successful.
+  std::vector<std::shared_ptr<rpc::FutureMessage>> outStandingRpcs_;
+
   // Lock to protect concurrent modification of the context.
   mutable std::mutex lock_;
 };
diff --git a/torch/csrc/distributed/autograd/engine/dist_engine.cpp b/torch/csrc/distributed/autograd/engine/dist_engine.cpp
new file mode 100644
index 0000000..5387604
--- /dev/null
+++ b/torch/csrc/distributed/autograd/engine/dist_engine.cpp
@@ -0,0 +1,270 @@
+#include <queue>
+
+#include <torch/csrc/autograd/functions/accumulate_grad.h>
+#include <torch/csrc/autograd/input_buffer.h>
+#include <torch/csrc/distributed/autograd/context/dist_autograd_container.h>
+#include <torch/csrc/distributed/autograd/engine/dist_engine.h>
+
+namespace torch {
+namespace distributed {
+namespace autograd {
+
+using torch::autograd::AccumulateGrad;
+using torch::autograd::edge_list;
+using torch::autograd::Engine;
+using torch::autograd::GraphRoot;
+using torch::autograd::GraphTask;
+using torch::autograd::Node;
+using torch::autograd::validate_outputs;
+using torch::autograd::variable_list;
+
+DistEngine::DistEngine()
+    : initializedContextIds_(), engine_(Engine::get_default_engine()) {}
+
+DistEngine& DistEngine::getInstance() {
+  static DistEngine engine;
+  return engine;
+}
+
+void DistEngine::validateRootsAndRetrieveEdges(
+    const variable_list& roots,
+    edge_list& rootEdges,
+    variable_list& grads) {
+  TORCH_CHECK(!roots.empty(), "No tensors provided for gradient computation.");
+  TORCH_INTERNAL_ASSERT(rootEdges.empty());
+  TORCH_INTERNAL_ASSERT(grads.empty());
+
+  // Verify roots are all scalar and require gradients.
+  for (const auto& root : roots) {
+    TORCH_CHECK(
+        root.requires_grad(), "requires_grad not set on: ", root.name());
+    TORCH_CHECK(
+        root.numel() == 1,
+        root.name(),
+        " is not a scalar, all roots need to be scalar");
+    TORCH_CHECK(
+        root.grad_fn(),
+        root.name(),
+        " does not have a valid gradient function.");
+
+    // Compute the root edges and generate the appropriate gradients.
+    rootEdges.push_back(root.gradient_edge());
+    grads.push_back(at::ones_like(root));
+  }
+
+  // Validate rootEdges and grads.
+  validate_outputs(
+      rootEdges, grads, [](const std::string& msg) { return msg; });
+}
+
+void DistEngine::computeDependencies(
+    DistAutogradContext& autogradContext,
+    const edge_list& rootEdges,
+    const variable_list& grads,
+    const std::shared_ptr<Node>& graphRoot,
+    edge_list& outputEdges) {
+  TORCH_INTERNAL_ASSERT(graphRoot, "graphRoot is null!");
+
+  // Build the graph task and graph root.
+  auto graphTask = std::make_shared<GraphTask>(
+      /* keep_graph */ false, /* create_graph */ false, /* depth */ 0);
+
+  // Run BFS to traverse the graph locally. The roots of the graph are
+  // GraphRoot and all send functions for this autograd context.
+  std::unordered_set<Node*> seen;
+  std::queue<Node*> queue;
+  queue.push(static_cast<Node*>(graphRoot.get()));
+
+  auto sendFunctions = autogradContext.sendFunctions();
+
+  // Add all the send functions to the queue as roots.
+  for (const auto& mapEntry : sendFunctions) {
+    // Increment 'outstanding_tasks_' for GraphTask for each send_function
+    // since we want the local autograd engine to wait for all of them.
+    graphTask->outstanding_tasks_++;
+    queue.push(mapEntry.second.get());
+  }
+
+  edge_list recvBackwardEdges;
+  // Traverse the graph.
+  auto& dependencies = graphTask->dependencies_;
+  while (!queue.empty()) {
+    auto fn = queue.front();
+    queue.pop();
+
+    for (const auto& edge : fn->next_edges()) {
+      if (auto nextFn = edge.function.get()) {
+        dependencies[nextFn] += 1;
+        const bool wasInserted = seen.insert(nextFn).second;
+        if (wasInserted) {
+          // Seeing this function for the first time.
+          queue.push(nextFn);
+
+          if (nextFn->next_edges().empty()) {
+            TORCH_INTERNAL_ASSERT(
+                dynamic_cast<AccumulateGrad*>(nextFn) ||
+                dynamic_cast<RecvRpcBackward*>(nextFn));
+            // We have found a leaf node which should be either AccumulateGrad
+            // or RecvRpcBackward. Record the function
+            // to ensure we don't execute it and instead accumulate the grads on
+            // the autograd context. These functions would be passed in as the
+            // 'outputs' parameter of the vanilla autograd engine.
+
+            // We don't accumulate any grads in the context for RecvRpcBackward.
+            // RecvRpcBackward is added as an output edge to indicate it is a
+            // leaf node and this helps in properly computing dependencies for
+            // the local autograd graph. Putting RecvRpcBackward in
+            // 'outputEdges' means that this function needs to be executed
+            // (in line with our assumption for FAST mode that all send/recv
+            // functions are valid in the backward pass), and as a result all of
+            // its ancestors need to be executed as well.
+            if (dynamic_cast<RecvRpcBackward*>(nextFn)) {
+              recvBackwardEdges.emplace_back(edge);
+            }
+            outputEdges.emplace_back(edge);
+          }
+        }
+      }
+    }
+  }
+
+  // Now lets compute which functions need to be executed. The algorithm is as
+  // follows:
+  // 1. Create a dummy GraphRoot which points to all 'send' functions for this
+  //    context and the original graphRoot. Run 'init_to_execute' with the
+  //    outputEdges and the dummy GraphRoot. This ensures we mark
+  //    appropriate functions as needed if they are reachable only from a
+  //    specific 'send' function locally and not necessarily from the provided
+  //    roots.
+  // 2. For all edges in 'outputEdges' which point to 'RecvRpcBackward', mark
+  //    those functions as needed for execution. The reason for this is that
+  //    'init_to_execute' will mark these as not needed. But 'RecvRpcBackward'
+  //    is unique in the sense that we use it as a leaf node in graph to compute
+  //    needed execution accurately, but unlike AccumulateGrad, we do need to
+  //    execute this function.
+  if (!outputEdges.empty()) {
+    // Compute 'needed execution' starting from all 'send' functions and the
+    // original graphRoot.
+    edge_list edges;
+    // Create some dummy edges (input_nr not important for init_to_execute).
+    for (const auto& mapEntry : sendFunctions) {
+      edges.emplace_back(mapEntry.second, 0);
+    }
+
+    // Add the original graphRoot as an edge.
+    edges.emplace_back(graphRoot, 0);
+
+    // Create a dummy GraphRoot and run init_to_execute with it.
+    GraphRoot dummyRoot(edges, {});
+    graphTask->init_to_execute(dummyRoot, outputEdges);
+
+    // Mark all 'RecvRPCBackward' as needing execution.
+    for (const auto& recvBackwardEdge : recvBackwardEdges) {
+      graphTask->exec_info_[recvBackwardEdge.function.get()].needed_ = true;
+    }
+  }
+
+  // Let autograd context take ownership of the GraphTask.
+  autogradContext.setGraphTask(std::move(graphTask));
+}
+
+void DistEngine::runEngineAndAccumulateGradients(
+    DistAutogradContext& autogradContext,
+    const std::shared_ptr<Node>& graphRoot,
+    const edge_list& outputEdges) {
+  // Kick off autograd computation with the root node and retrieve all the
+  // gradients.
+  // TODO: make this non-blocking
+  // (https://github.com/pytorch/pytorch/issues/26359)
+  variable_list grads = engine_.execute_with_graph_task(
+      *autogradContext.retrieveGraphTask(), graphRoot);
+
+  // Accumulate all the gradients in the context.
+  TORCH_INTERNAL_ASSERT(grads.size() == outputEdges.size());
+  for (size_t i = 0; i < grads.size(); i++) {
+    // It is possible that the grad is not defined since a separate invocation
+    // of the autograd engine on the same node might actually compute this
+    // gradient.
+    // Also, we only accumulate grads for the 'AccumulateGrad' function.
+    if (grads[i].defined() &&
+        dynamic_cast<AccumulateGrad*>(outputEdges[i].function.get())) {
+      auto& variable =
+          std::static_pointer_cast<AccumulateGrad>(outputEdges[i].function)
+              ->variable;
+      autogradContext.accumulateGrad(variable, grads[i]);
+    }
+  }
+}
+
+void DistEngine::executeSendFunction(
+    DistAutogradContext& autogradContext,
+    const std::shared_ptr<Node>& sendFunction) {
+  std::unique_lock<std::mutex> lock(initializedContextIdsLock_);
+  if (initializedContextIds_.find(autogradContext.contextId()) ==
+      initializedContextIds_.end()) {
+    edge_list outputEdges;
+    // Pass in a dummy graphRoot since all send functions are the roots.
+    auto dummyRoot = std::make_shared<GraphRoot>(edge_list(), variable_list());
+    computeDependencies(autogradContext, {}, {}, dummyRoot, outputEdges);
+
+    // Mark the autograd context id as initialized and unlock.
+    initializedContextIds_.insert(autogradContext.contextId());
+    lock.unlock();
+
+    // Enqueue the current send function.
+    auto graphTask = autogradContext.retrieveGraphTask();
+    engine_.enqueue_blocked_task_on_cpu(torch::autograd::NodeTask(
+        graphTask.get(), sendFunction, torch::autograd::InputBuffer(0)));
+
+    // Run the autograd engine.
+    runEngineAndAccumulateGradients(autogradContext, dummyRoot, outputEdges);
+
+    // Wait for all of the outstanding rpcs to complete.
+    autogradContext.clearAndWaitForOutstandingRpcs();
+  } else {
+    lock.unlock();
+    auto graphTask = autogradContext.retrieveGraphTask();
+    engine_.enqueue_blocked_task_on_cpu(torch::autograd::NodeTask(
+        graphTask.get(), sendFunction, torch::autograd::InputBuffer(0)));
+  }
+}
+
+void DistEngine::execute(const variable_list& roots) {
+  // Get the current context, if it exists. This will throw if we don't have
+  // a valid context.
+  DistAutogradContext& autogradContext =
+      DistAutogradContainer::getInstance().currentContext();
+
+  // Perform initial pre-processing.
+  edge_list rootEdges;
+  variable_list grads;
+  validateRootsAndRetrieveEdges(roots, rootEdges, grads);
+
+  std::shared_ptr<Node> graphRoot =
+      std::make_shared<GraphRoot>(rootEdges, grads);
+  edge_list outputEdges;
+  // Compute dependencies locally, starting from all roots and all 'send'
+  // functions.
+  {
+    std::lock_guard<std::mutex> guard(initializedContextIdsLock_);
+    // Context should not have been initialized already.
+    TORCH_INTERNAL_ASSERT(
+        initializedContextIds_.find(autogradContext.contextId()) ==
+        initializedContextIds_.end());
+
+    computeDependencies(
+        autogradContext, rootEdges, grads, graphRoot, outputEdges);
+
+    // Mark the autograd context id as initialized.
+    initializedContextIds_.insert(autogradContext.contextId());
+  }
+
+  runEngineAndAccumulateGradients(autogradContext, graphRoot, outputEdges);
+
+  // Wait for all of the outstanding rpcs to complete.
+  autogradContext.clearAndWaitForOutstandingRpcs();
+}
+
+} // namespace autograd
+} // namespace distributed
+} // namespace torch
diff --git a/torch/csrc/distributed/autograd/engine/dist_engine.h b/torch/csrc/distributed/autograd/engine/dist_engine.h
new file mode 100644
index 0000000..ce06c18
--- /dev/null
+++ b/torch/csrc/distributed/autograd/engine/dist_engine.h
@@ -0,0 +1,96 @@
+#pragma once
+
+#include <mutex>
+#include <unordered_set>
+
+#include <torch/csrc/autograd/engine.h>
+#include <torch/csrc/autograd/function.h>
+#include <torch/csrc/autograd/functions/basic_ops.h>
+#include <torch/csrc/distributed/autograd/context/dist_autograd_context.h>
+
+namespace torch {
+namespace distributed {
+namespace autograd {
+
+// This is a singleton class responsible for running distributed backward
+// passes. This engine relies heavily on the vanilla autograd engine and tries
+// to re-use it as much as possible. This class is mostly responsible for the
+// distributed aspects of autograd and tries to hook into the autograd engine
+// where convenient.
+
+// Unlike the vanilla autograd engine, the distributed autograd engine
+// accumulates the gradients in the appropriate DistAutogradContext. This avoids
+// multiple trainer nodes stomping on each other's gradients.
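+//
+// There are two entry points into this engine: execute(), used on the node
+// that kicks off the distributed backward pass, and executeSendFunction(),
+// used on nodes that receive gradients for one of their 'send' functions over
+// RPC.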
+class TORCH_API DistEngine {
+ public:
+  // Retrieve the singleton instance.
+  static DistEngine& getInstance();
+
+  // Given a list of root variables, start the distributed backwards pass from
+  // these variables and accumulate all the gradients in the current autograd
+  // context on each node. This method is used to kick off distributed autograd
+  // on a single node.
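+  //
+  // A rough usage sketch (`loss` is a hypothetical root tensor; in practice
+  // this call is driven from the Python '_backward' binding):
+  //   torch::autograd::variable_list roots{loss};
+  //   DistEngine::getInstance().execute(roots);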
+  void execute(const torch::autograd::variable_list& roots);
+
+  // Given a send function to execute in the autograd engine, this method
+  // ensures we compute dependencies once for this node and enqueues the send
+  // function for execution in the engine.
+  // This method is used to kick off the autograd computation on a node when it
+  // receives gradients from the corresponding 'recv' method on another node.
+  // The gradients are accumulated in the provided autograd context.
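+  //
+  // This is driven from the RPC layer when a BACKWARD_AUTOGRAD_REQ message
+  // arrives (see request_callback_impl.cpp), roughly:
+  //   sendFunction->setGrads(gradientsCall.getGrads());
+  //   DistEngine::getInstance().executeSendFunction(
+  //       autogradContext, sendFunction);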
+  void executeSendFunction(
+      DistAutogradContext& autogradContext,
+      const std::shared_ptr<torch::autograd::Node>& sendFunction);
+
+ private:
+  // Make sure this is a singleton.
+  DistEngine();
+  ~DistEngine() = default;
+
+  DistEngine(const DistEngine&) = delete;
+  DistEngine& operator=(const DistEngine&) = delete;
+  DistEngine(DistEngine&&) = delete;
+  DistEngine& operator=(DistEngine&&) = delete;
+
+  // Validates the input roots for the backward computation and retrieves the
+  // appropriate root edges and corresponding gradients. Populates 'rootEdges'
+  // with the appropriate gradient edges and 'grads' with the gradients for
+  // each edge.
+  void validateRootsAndRetrieveEdges(
+      const torch::autograd::variable_list& roots,
+      torch::autograd::edge_list& rootEdges,
+      torch::autograd::variable_list& grads);
+
+  // Given the autograd context, root edges and grads, we compute dependencies
+  // for the local node and fill out the provided GraphTask and GraphRoot with
+  // appropriate information for the local autograd engine.
+  // We also determine all leaf nodes (functions) in the graph and accumulate
+  // them in 'outputEdges'.
+  void computeDependencies(
+      DistAutogradContext& context,
+      const torch::autograd::edge_list& rootEdges,
+      const torch::autograd::variable_list& grads,
+      const std::shared_ptr<torch::autograd::Node>& graphRoot,
+      torch::autograd::edge_list& outputEdges);
+
+  // Run the local autograd engine using the provided graphTask and graphRoot
+  // and accumulate the gradients for the functions in 'outputEdges' into the
+  // provided autograd context.
+  void runEngineAndAccumulateGradients(
+      DistAutogradContext& autogradContext,
+      const std::shared_ptr<torch::autograd::Node>& graphRoot,
+      const torch::autograd::edge_list& outputEdges);
+
+  // Set of autograd context ids which we have already initialized for
+  // distributed autograd on this node (i.e., we have already computed
+  // dependencies for them).
+  std::unordered_set<int64_t> initializedContextIds_;
+
+  mutable std::mutex initializedContextIdsLock_;
+
+  // Reference to local autograd engine.
+  torch::autograd::Engine& engine_;
+};
+
+} // namespace autograd
+} // namespace distributed
+} // namespace torch
diff --git a/torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp b/torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp
index aeecbb4..cf8b765 100644
--- a/torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp
+++ b/torch/csrc/distributed/autograd/functions/recvrpc_backward.cpp
@@ -1,24 +1,51 @@
 #include <torch/csrc/distributed/autograd/functions/recvrpc_backward.h>
 #include <ATen/core/functional.h>
+#include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.h>
+#include <torch/csrc/distributed/rpc/rpc_agent.h>
 
 namespace torch {
 namespace distributed {
 namespace autograd {
 
 using torch::autograd::Variable;
+using torch::autograd::variable_list;
 
-torch::autograd::variable_list RecvRpcBackward::apply(
-    torch::autograd::variable_list&& grads) {
+RecvRpcBackward::RecvRpcBackward(
+    const AutogradMetadata& autogradMetadata,
+    DistAutogradContext& autogradContext,
+    rpc::worker_id_t fromWorkerId)
+    : autogradMetadata_(autogradMetadata),
+      autogradContext_(autogradContext),
+      fromWorkerId_(fromWorkerId) {}
+
+variable_list RecvRpcBackward::apply(variable_list&& grads) {
   std::vector<Variable> outputGrads;
-  for (const auto& grad : grads) {
+  for (size_t i = 0; i < grads.size(); i++) {
+    const auto& grad = grads[i];
     if (grad.defined()) {
       outputGrads.emplace_back(grad);
     } else {
-      outputGrads.emplace_back(at::zeros_like(grad));
+      // Put in zeros for a tensor with no grad.
+      outputGrads.emplace_back(input_metadata(i).zeros_like());
     }
   }
 
-  return outputGrads;
+  // Send the gradients over the wire and record the future in the autograd
+  // context.
+  PropagateGradientsReq gradCall(autogradMetadata_, outputGrads);
+
+  // Send the gradients over to the appropriate node, looked up by the worker
+  // id from which we received the original RPC.
+  auto rpcAgent = rpc::RpcAgent::getDefaultRpcAgent();
+  auto futureMessage = rpcAgent->send(
+      rpcAgent->getWorkerInfo(fromWorkerId_), std::move(gradCall).toMessage());
+
+  // Record the future in the context.
+  autogradContext_.addOutstandingRpc(futureMessage);
+
+  // The 'recv' function sends the gradients over the wire using RPC; it
+  // doesn't need to return anything to any downstream autograd function.
+  return variable_list();
 }
 
 } // namespace autograd
diff --git a/torch/csrc/distributed/autograd/functions/recvrpc_backward.h b/torch/csrc/distributed/autograd/functions/recvrpc_backward.h
index fbb2e9f..d5245d0 100644
--- a/torch/csrc/distributed/autograd/functions/recvrpc_backward.h
+++ b/torch/csrc/distributed/autograd/functions/recvrpc_backward.h
@@ -1,19 +1,40 @@
 #pragma once
 
 #include <torch/csrc/autograd/function.h>
+#include <torch/csrc/distributed/autograd/context/dist_autograd_context.h>
+#include <torch/csrc/distributed/autograd/rpc_messages/autograd_metadata.h>
+#include <torch/csrc/distributed/rpc/rpc_agent.h>
 
 namespace torch {
 namespace distributed {
 namespace autograd {
 
+class DistAutogradContext;
+
 // As part of our distributed autograd implementation, whenever we receive an
 // RPC from a node, we add a 'RecvRpcBackward' autograd function to the
 // autograd graph. This is more or less a placeholder function that is used to
 // pass gradients to the remote host during the backward pass. The inputs to the
 // RPC function are the inputs to this autograd function.
-struct TORCH_API RecvRpcBackward : public torch::autograd::Node {
+class TORCH_API RecvRpcBackward : public torch::autograd::Node {
+ public:
+  explicit RecvRpcBackward(
+      const AutogradMetadata& autogradMetadata,
+      DistAutogradContext& autogradContext,
+      rpc::worker_id_t fromWorkerId);
+
   torch::autograd::variable_list apply(
       torch::autograd::variable_list&& grads) override;
+
+ private:
+  const AutogradMetadata autogradMetadata_;
+
+  // Hold a reference to the autograd context.
+  DistAutogradContext& autogradContext_;
+
+  // The worker id from which the RPC was received. During the backward pass,
+  // we need to propagate the gradients to this workerId.
+  rpc::worker_id_t fromWorkerId_;
 };
 
 } // namespace autograd
diff --git a/torch/csrc/distributed/autograd/functions/sendrpc_backward.cpp b/torch/csrc/distributed/autograd/functions/sendrpc_backward.cpp
index eb615bc..211c3d7 100644
--- a/torch/csrc/distributed/autograd/functions/sendrpc_backward.cpp
+++ b/torch/csrc/distributed/autograd/functions/sendrpc_backward.cpp
@@ -5,16 +5,22 @@
 namespace autograd {
 
 torch::autograd::variable_list SendRpcBackward::apply(
-    torch::autograd::variable_list&& grads) {
+    torch::autograd::variable_list&& inputs) {
+  TORCH_INTERNAL_ASSERT(
+      inputs.empty(), "SendRpcBackward should receive no inputs");
+
   // Each grad variable should be valid!
-  for (const auto& grad : grads) {
-    TORCH_CHECK(
+  for (const auto& grad : grads_) {
+    TORCH_INTERNAL_ASSERT(
         grad.defined(), "BUG!: SendRpcBackward didn't receive valid gradients");
   }
 
   // Simply forwards the gradients over.
-  // TODO: Improve this as we build out more parts of distributed autograd.
-  return std::move(grads);
+  return std::move(grads_);
+}
+
+void SendRpcBackward::setGrads(const torch::autograd::variable_list& grads) {
+  grads_ = grads;
 }
 
 } // namespace autograd
diff --git a/torch/csrc/distributed/autograd/functions/sendrpc_backward.h b/torch/csrc/distributed/autograd/functions/sendrpc_backward.h
index e6713af..6203c9b 100644
--- a/torch/csrc/distributed/autograd/functions/sendrpc_backward.h
+++ b/torch/csrc/distributed/autograd/functions/sendrpc_backward.h
@@ -15,8 +15,18 @@
 // During the backward pass, this function is queued for execution in the
 // autograd engine which eventually runs the rest of the autograd graph.
 struct TORCH_API SendRpcBackward : public torch::autograd::Node {
+ public:
   torch::autograd::variable_list apply(
-      torch::autograd::variable_list&& grads) override;
+      torch::autograd::variable_list&& inputs) override;
+
+  // SendRpcBackward is actually the root of an autograd graph on the local
+  // node. As a result, it doesn't receive any 'inputs', but rather the RPC
+  // framework passes gradients over to this function to kick off local
+  // autograd computation.
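+  //
+  // The gradients stashed here are what apply() returns (and moves out) once
+  // the local autograd engine executes this node.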
+  void setGrads(const torch::autograd::variable_list& grads);
+
+ private:
+  torch::autograd::variable_list grads_;
 };
 
 } // namespace autograd
diff --git a/torch/csrc/distributed/autograd/init.cpp b/torch/csrc/distributed/autograd/init.cpp
index fa214df..0f914c3 100644
--- a/torch/csrc/distributed/autograd/init.cpp
+++ b/torch/csrc/distributed/autograd/init.cpp
@@ -1,5 +1,6 @@
 #include <torch/csrc/autograd/python_cpp_function.h>
 #include <torch/csrc/distributed/autograd/context/dist_autograd_container.h>
+#include <torch/csrc/distributed/autograd/engine/dist_engine.h>
 #include <torch/csrc/jit/pybind_utils.h>
 #include <torch/csrc/python_headers.h>
 #include <torch/csrc/utils/object_ptr.h>
@@ -28,7 +29,7 @@
       shared_ptr_class_<DistAutogradContext>(module, "DistAutogradContext")
           .def(
               "_context_id",
-              &DistAutogradContext::context_id,
+              &DistAutogradContext::contextId,
               py::call_guard<py::gil_scoped_release>())
           .def(
               "_recv_functions",
@@ -61,9 +62,12 @@
       },
       py::return_value_policy::reference);
 
-  module.def("_release_context", [](int64_t context_id) {
-    return DistAutogradContainer::getInstance().releaseContext(context_id);
-  });
+  module.def(
+      "_release_context",
+      [](int64_t context_id) {
+        return DistAutogradContainer::getInstance().releaseContext(context_id);
+      },
+      py::call_guard<py::gil_scoped_release>());
 
   module.def("_get_max_id", []() {
     return DistAutogradContainer::getInstance().getMaxId();
@@ -83,8 +87,26 @@
       },
       py::return_value_policy::reference);
 
-  module.def("_init", [](int64_t worker_id) {
-    DistAutogradContainer::init(worker_id);
+  module.def(
+      "_init",
+      [](int64_t worker_id) { DistAutogradContainer::init(worker_id); },
+      py::call_guard<py::gil_scoped_release>());
+
+  module.def(
+      "_backward",
+      [](const std::vector<torch::Tensor>& roots) {
+        torch::autograd::variable_list variables;
+        for (const auto& root : roots) {
+          variables.emplace_back(root);
+        }
+        DistEngine::getInstance().execute(variables);
+      },
+      py::call_guard<py::gil_scoped_release>());
+
+  module.def("get_gradients", [](int64_t contextId) {
+    const auto& autogradContext =
+        DistAutogradContainer::getInstance().retrieveContext(contextId);
+    return torch::jit::toPyObject(IValue(autogradContext.getGradients()));
   });
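+
+  // A rough sketch of the intended Python-side usage of these bindings (the
+  // module alias, 'loss' and 'context_id' below are illustrative assumptions):
+  //   import torch.distributed.autograd as dist_autograd
+  //   dist_autograd._backward([loss])
+  //   grads = dist_autograd.get_gradients(context_id)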
 
   Py_RETURN_TRUE;
diff --git a/torch/csrc/distributed/autograd/rpc_messages/autograd_metadata.cpp b/torch/csrc/distributed/autograd/rpc_messages/autograd_metadata.cpp
new file mode 100644
index 0000000..6982e23
--- /dev/null
+++ b/torch/csrc/distributed/autograd/rpc_messages/autograd_metadata.cpp
@@ -0,0 +1,15 @@
+#include <torch/csrc/distributed/autograd/rpc_messages/autograd_metadata.h>
+
+namespace torch {
+namespace distributed {
+namespace autograd {
+
+AutogradMetadata::AutogradMetadata(
+    int64_t autogradContextId_,
+    int64_t autogradMessageId_)
+    : autogradContextId(autogradContextId_),
+      autogradMessageId(autogradMessageId_) {}
+
+} // namespace autograd
+} // namespace distributed
+} // namespace torch
diff --git a/torch/csrc/distributed/autograd/rpc_messages/autograd_metadata.h b/torch/csrc/distributed/autograd/rpc_messages/autograd_metadata.h
new file mode 100644
index 0000000..41067ca
--- /dev/null
+++ b/torch/csrc/distributed/autograd/rpc_messages/autograd_metadata.h
@@ -0,0 +1,25 @@
+#pragma once
+
+#include <torch/csrc/WindowsTorchApiMacro.h>
+#include <cstdint>
+
+namespace torch {
+namespace distributed {
+namespace autograd {
+
+// This structure represents autograd metadata that we need to pass across
+// different nodes when we call an RPC which needs autograd computation.
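+//
+// For example, when wrapping an outgoing RPC (see python_functions.cpp), the
+// metadata is built roughly as:
+//   AutogradMetadata autogradMetadata(
+//       autogradContext.contextId(),
+//       autogradContainer.newAutogradMessageId());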
+struct TORCH_API AutogradMetadata {
+  AutogradMetadata(int64_t autogradContextId, int64_t autogradMessageId);
+
+  // autogradContextId is a globally unique integer that identifies a
+  // particular distributed autograd pass.
+  int64_t autogradContextId;
+  // autogradMessageId is a globally unique integer that identifies a pair
+  // of send/recv autograd functions.
+  int64_t autogradMessageId;
+};
+
+} // namespace autograd
+} // namespace distributed
+} // namespace torch
diff --git a/torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.cpp b/torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.cpp
new file mode 100644
index 0000000..237d4d0
--- /dev/null
+++ b/torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.cpp
@@ -0,0 +1,81 @@
+#include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.h>
+#include <torch/csrc/jit/pickle.h>
+
+namespace torch {
+namespace distributed {
+namespace autograd {
+
+using rpc::Message;
+using rpc::MessageType;
+using torch::autograd::Variable;
+
+PropagateGradientsReq::PropagateGradientsReq(
+    const AutogradMetadata& autogradMetadata,
+    std::vector<Variable> grads)
+    : autogradMetadata_(autogradMetadata), grads_(std::move(grads)) {}
+
+Message PropagateGradientsReq::toMessage() && {
+  std::vector<at::IValue> ivalues;
+  // Add all the grad tensors.
+  for (const auto& grad : grads_) {
+    ivalues.emplace_back(grad);
+  }
+
+  // Now add autograd metadata.
+  ivalues.emplace_back(autogradMetadata_.autogradContextId);
+  ivalues.emplace_back(autogradMetadata_.autogradMessageId);
+
+  // Now pickle using JIT pickler.
+  std::vector<torch::Tensor> tensorTable;
+  std::vector<char> payload =
+      jit::pickle(c10::ivalue::Tuple::create(std::move(ivalues)), &tensorTable);
+
+  return Message(
+      std::move(payload),
+      std::move(tensorTable),
+      MessageType::BACKWARD_AUTOGRAD_REQ);
+}
+
+std::unique_ptr<PropagateGradientsReq> PropagateGradientsReq::fromMessage(
+    const Message& message) {
+  // Unpickle the message and retrieve tupleElements.
+  auto payload = static_cast<const char*>(message.payload().data());
+  auto payload_size = message.payload().size();
+  IValue tuple =
+      jit::unpickle(payload, payload_size, nullptr, &message.tensors());
+  std::vector<at::IValue> tupleElements = tuple.toTuple()->elements();
+
+  // Build PropagateGradientsReq.
+  TORCH_INTERNAL_ASSERT(tupleElements.size() >= 2);
+
+  // Build AutogradMetadata.
+  int64_t autogradContextId, autogradMessageId;
+  autogradMessageId = tupleElements.back().toInt();
+  tupleElements.pop_back();
+  autogradContextId = tupleElements.back().toInt();
+  tupleElements.pop_back();
+
+  AutogradMetadata autogradMetadata(autogradContextId, autogradMessageId);
+
+  // Retrieve the gradient tensors.
+  std::vector<Variable> grads(tupleElements.size());
+  for (size_t i = 0; i < tupleElements.size(); i++) {
+    grads[i] = tupleElements[i].toTensor();
+  }
+
+  return std::unique_ptr<PropagateGradientsReq>(
+      new PropagateGradientsReq(autogradMetadata, grads));
+}
+
+const AutogradMetadata& PropagateGradientsReq::getAutogradMetadata() {
+  return autogradMetadata_;
+}
+
+const std::vector<torch::autograd::Variable>& PropagateGradientsReq::
+    getGrads() {
+  return grads_;
+}
+
+} // namespace autograd
+} // namespace distributed
+} // namespace torch
diff --git a/torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.h b/torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.h
new file mode 100644
index 0000000..063eda8
--- /dev/null
+++ b/torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.h
@@ -0,0 +1,37 @@
+#pragma once
+
+#include <torch/csrc/distributed/autograd/rpc_messages/autograd_metadata.h>
+#include <torch/csrc/distributed/rpc/message.h>
+#include <torch/csrc/distributed/rpc/rpc_command_base.h>
+#include <vector>
+
+namespace torch {
+namespace distributed {
+namespace autograd {
+
+// Used to propagate gradients from one node to another during a distributed
+// backwards pass. This RPC call is invoked when we hit a `recv` autograd
+// function during backward pass execution.
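+//
+// A rough sketch of how this request is built and sent from
+// RecvRpcBackward::apply():
+//   PropagateGradientsReq gradCall(autogradMetadata, outputGrads);
+//   auto futureMessage = rpcAgent->send(
+//       rpcAgent->getWorkerInfo(fromWorkerId),
+//       std::move(gradCall).toMessage());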
+class TORCH_API PropagateGradientsReq : public rpc::RpcCommandBase {
+ public:
+  PropagateGradientsReq(
+      const AutogradMetadata& autogradMetadata,
+      std::vector<torch::autograd::Variable> grads);
+
+  const AutogradMetadata& getAutogradMetadata();
+
+  const std::vector<torch::autograd::Variable>& getGrads();
+
+  // Serialization and deserialization methods.
+  rpc::Message toMessage() && override;
+  static std::unique_ptr<PropagateGradientsReq> fromMessage(
+      const rpc::Message& message);
+
+ private:
+  AutogradMetadata autogradMetadata_;
+  std::vector<torch::autograd::Variable> grads_;
+};
+
+} // namespace autograd
+} // namespace distributed
+} // namespace torch
diff --git a/torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_resp.cpp b/torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_resp.cpp
new file mode 100644
index 0000000..3fa65d2
--- /dev/null
+++ b/torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_resp.cpp
@@ -0,0 +1,18 @@
+#include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_resp.h>
+
+namespace torch {
+namespace distributed {
+namespace autograd {
+
+rpc::Message PropagateGradientsResp::toMessage() && {
+  return rpc::Message({}, {}, rpc::MessageType::BACKWARD_AUTOGRAD_RESP);
+}
+
+std::unique_ptr<PropagateGradientsResp> PropagateGradientsResp::fromMessage(
+    const rpc::Message& message) {
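+  // The response carries no payload, so there is nothing to deserialize;
+  // return an empty pointer for now.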
+  return std::unique_ptr<PropagateGradientsResp>();
+}
+
+} // namespace autograd
+} // namespace distributed
+} // namespace torch
diff --git a/torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_resp.h b/torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_resp.h
new file mode 100644
index 0000000..8459208
--- /dev/null
+++ b/torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_resp.h
@@ -0,0 +1,24 @@
+#pragma once
+
+#include <torch/csrc/distributed/rpc/message.h>
+#include <torch/csrc/distributed/rpc/rpc_command_base.h>
+
+namespace torch {
+namespace distributed {
+namespace autograd {
+
+// Response for the PropagateGradients call. Currently, this class is mostly
+// just a placeholder and sends an empty message over the wire. The purpose of
+// this RPC command is to indicate whether or not the PropagateGradientsReq
+// call was successful.
+class TORCH_API PropagateGradientsResp : public rpc::RpcCommandBase {
+ public:
+  PropagateGradientsResp() = default;
+  rpc::Message toMessage() && override;
+  static std::unique_ptr<PropagateGradientsResp> fromMessage(
+      const rpc::Message& message);
+};
+
+} // namespace autograd
+} // namespace distributed
+} // namespace torch
diff --git a/torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.cpp b/torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.cpp
new file mode 100644
index 0000000..22c2cf3
--- /dev/null
+++ b/torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.cpp
@@ -0,0 +1,176 @@
+#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
+#include <c10/util/C++17.h>
+#include <torch/csrc/distributed/rpc/utils.h>
+#include <torch/csrc/jit/pickle.h>
+#include <torch/csrc/utils/byte_order.h>
+
+namespace torch {
+namespace distributed {
+namespace autograd {
+
+using rpc::Message;
+using rpc::MessageType;
+using rpc::RpcCommandBase;
+using rpc::worker_id_t;
+
+RpcWithAutograd::RpcWithAutograd(
+    worker_id_t fromWorkerId,
+    MessageType messageType,
+    const AutogradMetadata& autogradMetadata,
+    std::unique_ptr<RpcCommandBase> wrappedRpc)
+    : fromWorkerId_(fromWorkerId),
+      messageType_(messageType),
+      autogradMetadata_(autogradMetadata) {
+  TORCH_INTERNAL_ASSERT(wrappedRpc != nullptr, "wrappedRpc cannot be null!");
+  TORCH_INTERNAL_ASSERT(
+      messageType_ == MessageType::FORWARD_AUTOGRAD_REQ ||
+      messageType_ == MessageType::FORWARD_AUTOGRAD_RESP);
+  wrappedMessage_ = std::move(*wrappedRpc).toMessage();
+  tensors_ = wrappedMessage_.tensors();
+  wrappedMessageType_ = wrappedMessage_.type();
+}
+
+RpcWithAutograd::RpcWithAutograd(
+    worker_id_t fromWorkerId,
+    MessageType messageType,
+    const AutogradMetadata& autogradMetadata,
+    std::unique_ptr<RpcCommandBase> wrappedRpc,
+    MessageType wrappedMessageType,
+    std::vector<torch::Tensor> tensors)
+    : fromWorkerId_(fromWorkerId),
+      messageType_(messageType),
+      autogradMetadata_(autogradMetadata),
+      wrappedRpc_(std::move(wrappedRpc)),
+      wrappedMessageType_(wrappedMessageType),
+      tensors_(std::move(tensors)) {
+  TORCH_INTERNAL_ASSERT(wrappedRpc_ != nullptr, "wrappedRpc cannot be null!");
+  TORCH_INTERNAL_ASSERT(
+      messageType_ == MessageType::FORWARD_AUTOGRAD_REQ ||
+      messageType_ == MessageType::FORWARD_AUTOGRAD_RESP);
+}
+
+Message RpcWithAutograd::toMessage() && {
+  auto messageId = wrappedMessage_.id();
+  auto messageType = wrappedMessage_.type();
+
+  auto payload = std::move(wrappedMessage_).movePayload();
+  TORCH_INTERNAL_ASSERT(!payload.empty());
+
+  std::vector<at::IValue> ivalues{messageType,
+                                  autogradMetadata_.autogradContextId,
+                                  autogradMetadata_.autogradMessageId,
+                                  fromWorkerId_};
+
+  // Now pickle using JIT pickler.
+  std::vector<torch::Tensor> tensorTable;
+  std::vector<char> additionalPayload =
+      jit::pickle(c10::ivalue::Tuple::create(std::move(ivalues)), &tensorTable);
+
+  // We shouldn't have any tensors!
+  TORCH_INTERNAL_ASSERT(tensorTable.empty());
+
+  // Append the payload.
+  payload.insert(
+      payload.end(), additionalPayload.begin(), additionalPayload.end());
+
+  // Add size of the additional payload.
+  int64_t indexToWrite = payload.size();
+  payload.resize(payload.size() + sizeof(int64_t));
+  const int64_t additionalPayloadSize = additionalPayload.size();
+  torch::utils::THP_encodeInt64Buffer(
+      reinterpret_cast<uint8_t*>(payload.data()) + indexToWrite,
+      &additionalPayloadSize,
+      torch::utils::THPByteOrder::THP_BIG_ENDIAN,
+      1);
+
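+  // The resulting payload layout is:
+  //   [wrapped RPC payload | pickled autograd tuple | 8-byte big-endian size
+  //    of the pickled autograd tuple]
+  // fromMessage() reads the trailing size first in order to strip these
+  // fields off again.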
+  return Message(
+      std::move(payload), std::move(tensors_), messageType_, messageId);
+}
+
+std::unique_ptr<RpcWithAutograd> RpcWithAutograd::fromMessage(
+    const Message& message) {
+  MessageType originalMessageType = message.type();
+  TORCH_INTERNAL_ASSERT(
+      MessageType::FORWARD_AUTOGRAD_REQ == originalMessageType ||
+      MessageType::FORWARD_AUTOGRAD_RESP == originalMessageType);
+
+  std::vector<torch::Tensor> tensors = message.tensors();
+  int64_t messageId = message.id();
+  // Decode message type, autograd context id, autograd message id and worker
+  // id from which we received this message.
+  auto payload = message.payload();
+
+  // Read the size of the autograd payload and remove it from the payload.
+  int64_t autogradPayLoadSize;
+  size_t indexToRead = payload.size() - sizeof(int64_t);
+  TORCH_INTERNAL_ASSERT(indexToRead >= 0);
+  torch::utils::THP_decodeInt64Buffer(
+      &autogradPayLoadSize,
+      reinterpret_cast<uint8_t*>(payload.data()) + indexToRead,
+      torch::utils::THPByteOrder::THP_BIG_ENDIAN,
+      1);
+  payload.resize(indexToRead);
+
+  // Now read the entire autograd payload and unpickle.
+  TORCH_INTERNAL_ASSERT(payload.size() > autogradPayLoadSize);
+  auto autogradPayLoadBegin =
+      static_cast<const char*>(message.payload().data()) + payload.size() -
+      autogradPayLoadSize;
+  std::vector<torch::Tensor> tensorTable;
+  IValue tuple = jit::unpickle(
+      autogradPayLoadBegin, autogradPayLoadSize, nullptr, &tensorTable);
+  std::vector<at::IValue> tupleElements = tuple.toTuple()->elements();
+
+  // Gather all the fields.
+  TORCH_INTERNAL_ASSERT(tupleElements.size() == 4);
+  MessageType wrappedMessageType =
+      static_cast<MessageType>(tupleElements[0].toInt());
+  AutogradMetadata autogradMetadata(
+      tupleElements[1].toInt(), tupleElements[2].toInt());
+  worker_id_t workerId = tupleElements[3].toInt();
+  payload.resize(payload.size() - autogradPayLoadSize);
+
+  // Build the wrapped message and deserialize the wrapped RPC from it.
+  Message wrappedMessage(
+      std::move(payload), std::move(tensors), wrappedMessageType, messageId);
+
+  std::unique_ptr<RpcCommandBase> wrappedRpc;
+  if (originalMessageType == MessageType::FORWARD_AUTOGRAD_REQ) {
+    wrappedRpc = deserializeRequest(wrappedMessage);
+  } else {
+    wrappedRpc = deserializeResponse(wrappedMessage);
+  }
+
+  return c10::guts::make_unique<RpcWithAutograd>(
+      workerId,
+      originalMessageType,
+      autogradMetadata,
+      std::move(wrappedRpc),
+      wrappedMessageType,
+      wrappedMessage.tensors());
+}
+
+std::vector<torch::Tensor>& RpcWithAutograd::tensors() {
+  return tensors_;
+}
+
+const AutogradMetadata& RpcWithAutograd::autogradMetadata() const {
+  return autogradMetadata_;
+}
+
+RpcCommandBase& RpcWithAutograd::wrappedRpc() {
+  TORCH_INTERNAL_ASSERT(wrappedRpc_ != nullptr, "wrappedRpc cannot be null!");
+  return *wrappedRpc_;
+}
+
+MessageType RpcWithAutograd::wrappedMessageType() const {
+  return wrappedMessageType_;
+}
+
+rpc::worker_id_t RpcWithAutograd::fromWorkerId() const {
+  return fromWorkerId_;
+}
+
+} // namespace autograd
+} // namespace distributed
+} // namespace torch
diff --git a/torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h b/torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h
new file mode 100644
index 0000000..e5b6c94
--- /dev/null
+++ b/torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h
@@ -0,0 +1,76 @@
+#pragma once
+
+#include <torch/csrc/distributed/autograd/rpc_messages/autograd_metadata.h>
+#include <torch/csrc/distributed/rpc/rpc_agent.h>
+#include <torch/csrc/distributed/rpc/rpc_command_base.h>
+
+namespace torch {
+namespace distributed {
+namespace autograd {
+
+// Represents an RPC that includes autograd information. This class basically
+// wraps another `RpcCommandBase` object which represents the actual RPC and has
+// additional autograd information associated with that RPC.
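+//
+// For example, an outgoing request is wrapped roughly as follows (see
+// python_functions.cpp):
+//   RpcWithAutograd rpcWithAutograd(
+//       agent.getWorkerInfo().id_,
+//       MessageType::FORWARD_AUTOGRAD_REQ,
+//       autogradMetadata,
+//       std::move(scriptCall));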
+class TORCH_API RpcWithAutograd final : public rpc::RpcCommandBase {
+ public:
+  // Used when we are sending an RPC over the wire.
+  RpcWithAutograd(
+      rpc::worker_id_t fromWorkerId,
+      rpc::MessageType messageType,
+      const AutogradMetadata& autogradMetadata,
+      std::unique_ptr<rpc::RpcCommandBase> wrappedRpc);
+
+  // Used when receiving an RPC over the wire.
+  RpcWithAutograd(
+      rpc::worker_id_t fromWorkerId,
+      rpc::MessageType messageType,
+      const AutogradMetadata& autogradMetadata,
+      std::unique_ptr<rpc::RpcCommandBase> wrappedRpc,
+      rpc::MessageType wrappedMessageType,
+      std::vector<torch::Tensor> tensors);
+
+  rpc::Message toMessage() && override;
+
+  static std::unique_ptr<RpcWithAutograd> fromMessage(
+      const rpc::Message& message);
+
+  // Retrieves the tensors that are part of this RPC and need to be considered
+  // for autograd computations.
+  std::vector<torch::Tensor>& tensors();
+
+  const AutogradMetadata& autogradMetadata() const;
+
+  RpcCommandBase& wrappedRpc();
+
+  // Message type of the wrapped RPC.
+  rpc::MessageType wrappedMessageType() const;
+
+  // Retrieve the worker id from which the RPC originated.
+  rpc::worker_id_t fromWorkerId() const;
+
+ private:
+  // WorkerId from which this RPC originated. This is necessary for knowing
+  // which worker we need to contact during the backward pass.
+  rpc::worker_id_t fromWorkerId_;
+
+  // Message type for this call.
+  rpc::MessageType messageType_;
+
+  AutogradMetadata autogradMetadata_;
+  std::unique_ptr<RpcCommandBase> wrappedRpc_;
+
+  // Serialized message representing wrappedRpc_. Used mostly as a cache to
+  // avoid serializing the request twice.
+  rpc::Message wrappedMessage_;
+
+  // Message type of the wrappedMessage_. This is stored separately since
+  // wrappedMessage_ is not always guaranteed to be populated.
+  rpc::MessageType wrappedMessageType_;
+
+  // Tensors that are part of the wrappedRpc_ and need to be considered for
+  // autograd.
+  std::vector<torch::Tensor> tensors_;
+};
+
+} // namespace autograd
+} // namespace distributed
+} // namespace torch
diff --git a/torch/csrc/distributed/autograd/utils.cpp b/torch/csrc/distributed/autograd/utils.cpp
index 3eb19d4..524765d 100644
--- a/torch/csrc/distributed/autograd/utils.cpp
+++ b/torch/csrc/distributed/autograd/utils.cpp
@@ -3,6 +3,7 @@
 #include <torch/csrc/distributed/autograd/functions/recvrpc_backward.h>
 #include <torch/csrc/distributed/autograd/functions/sendrpc_backward.h>
 #include <torch/csrc/distributed/autograd/utils.h>
+#include <torch/csrc/distributed/rpc/rpc_agent.h>
 
 namespace torch {
 namespace distributed {
@@ -12,7 +13,7 @@
 
 void addSendRpcBackward(
     DistAutogradContext& autogradContext,
-    const torch::distributed::rpc::AutogradMetadata& autogradMetadata,
+    const AutogradMetadata& autogradMetadata,
     std::vector<torch::Tensor>& tensors) {
   // Attach the appropriate autograd edges.
   if (torch::autograd::compute_requires_grad(tensors)) {
@@ -31,20 +32,23 @@
 }
 
 DistAutogradContext* addRecvRpcBackward(
-    const torch::distributed::rpc::AutogradMetadata& autogradMetadata,
-    std::vector<torch::Tensor>& tensors) {
+    const AutogradMetadata& autogradMetadata,
+    std::vector<torch::Tensor>& tensors,
+    rpc::worker_id_t fromWorkerId) {
   if (torch::autograd::compute_requires_grad(tensors)) {
+    // Initialize autograd context if necessary.
+    auto& autogradContainer = DistAutogradContainer::getInstance();
+    DistAutogradContext& autogradContext = autogradContainer.getOrCreateContext(
+        autogradMetadata.autogradContextId);
+
     // Attach the tensors as inputs to the autograd function.
-    auto grad_fn = std::make_shared<RecvRpcBackward>();
+    auto grad_fn = std::make_shared<RecvRpcBackward>(
+        autogradMetadata, autogradContext, fromWorkerId);
     for (auto& tensor : tensors) {
       torch::autograd::set_history(tensor, grad_fn);
     }
 
     // Now update the autograd context with the necessary information.
-    auto& autogradContainer = DistAutogradContainer::getInstance();
-    // Initialize autograd context if necessary.
-    DistAutogradContext& autogradContext = autogradContainer.getOrCreateContext(
-        autogradMetadata.autogradContextId);
     autogradContext.addRecvFunction(
         grad_fn, autogradMetadata.autogradMessageId);
     return &autogradContext;
diff --git a/torch/csrc/distributed/autograd/utils.h b/torch/csrc/distributed/autograd/utils.h
index 3126f63..ab9dc3a 100644
--- a/torch/csrc/distributed/autograd/utils.h
+++ b/torch/csrc/distributed/autograd/utils.h
@@ -1,7 +1,7 @@
 #pragma once
 
 #include <torch/csrc/distributed/autograd/context/dist_autograd_context.h>
-#include <torch/csrc/distributed/rpc/rpc_with_autograd.h>
+#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
 
 namespace torch {
 namespace distributed {
@@ -15,7 +15,7 @@
 // autograd information for the recipient.
 TORCH_API void addSendRpcBackward(
     DistAutogradContext& autogradContext,
-    const torch::distributed::rpc::AutogradMetadata& autogradMetadata,
+    const AutogradMetadata& autogradMetadata,
     std::vector<torch::Tensor>& tensors);
 
 // This method is used to attach the 'recv' autograd function to the autograd
@@ -27,8 +27,9 @@
 // Returns a pointer to the autograd context created (nullptr in case of no
 // autograd information was needed.)
 TORCH_API DistAutogradContext* addRecvRpcBackward(
-    const torch::distributed::rpc::AutogradMetadata& autogradMetadata,
-    std::vector<torch::Tensor>& tensors);
+    const AutogradMetadata& autogradMetadata,
+    std::vector<torch::Tensor>& tensors,
+    rpc::worker_id_t fromWorkerId);
 
 } // namespace autograd
 } // namespace distributed
diff --git a/torch/csrc/distributed/rpc/future_message.cpp b/torch/csrc/distributed/rpc/future_message.cpp
index 92aef39..9df291a 100644
--- a/torch/csrc/distributed/rpc/future_message.cpp
+++ b/torch/csrc/distributed/rpc/future_message.cpp
@@ -7,6 +7,12 @@
 const Message& FutureMessage::wait() {
   std::unique_lock<std::mutex> lock(mutex_);
   finished_cv_.wait(lock, [this] { return completed_.load(); });
+
+  // Throw an exception if we encounter one.
+  if (message_.type() == MessageType::EXCEPTION) {
+    std::string err(message_.payload().begin(), message_.payload().end());
+    throw std::runtime_error(err);
+  }
   return message_;
 }
 
diff --git a/torch/csrc/distributed/rpc/init.cpp b/torch/csrc/distributed/rpc/init.cpp
index 8448ad3..27bbd90 100644
--- a/torch/csrc/distributed/rpc/init.cpp
+++ b/torch/csrc/distributed/rpc/init.cpp
@@ -103,12 +103,12 @@
           &ProcessGroupAgent::sync,
           py::call_guard<py::gil_scoped_release>());
 
-  module.def("_init_rref_context", [](std::shared_ptr<RpcAgent> agent) {
-    RRefContext::initInstance(std::move(agent));
+  module.def("_init_rpc_agent", [](std::shared_ptr<RpcAgent> agent) {
+    RpcAgent::setDefaultRpcAgent(std::move(agent));
   });
 
   module.def("_destroy_rref_context", []() {
-    RRefContext::getInstance()->destroyInstance();
+    RRefContext::getInstance().destroyInstance();
   });
 
   module.def(
diff --git a/torch/csrc/distributed/rpc/message.cpp b/torch/csrc/distributed/rpc/message.cpp
index 18a68b1..fc6907d 100644
--- a/torch/csrc/distributed/rpc/message.cpp
+++ b/torch/csrc/distributed/rpc/message.cpp
@@ -76,7 +76,8 @@
       MessageType::RREF_CHILD_ACCEPT == type_ ||
       MessageType::RREF_FORK_REQUEST == type_ ||
       // Autograd message
-      MessageType::MESSAGE_WITH_AUTOGRAD_REQ == type_;
+      MessageType::BACKWARD_AUTOGRAD_REQ == type_ ||
+      MessageType::FORWARD_AUTOGRAD_REQ == type_;
 }
 
 bool Message::isResponse() const {
@@ -87,7 +88,8 @@
       MessageType::EXCEPTION == type_ || // propagate back exceptions
       MessageType::RREF_ACK == type_ || // ret of other types
       // Autograd response
-      MessageType::MESSAGE_WITH_AUTOGRAD_RESP == type_;
+      MessageType::BACKWARD_AUTOGRAD_RESP == type_ ||
+      MessageType::FORWARD_AUTOGRAD_RESP == type_;
 }
 
 bool Message::isShutdown() const {
diff --git a/torch/csrc/distributed/rpc/message.h b/torch/csrc/distributed/rpc/message.h
index f9d1e7c..c83cab3 100644
--- a/torch/csrc/distributed/rpc/message.h
+++ b/torch/csrc/distributed/rpc/message.h
@@ -31,13 +31,17 @@
   RREF_ACK = 13, // ACK to internal RRef messages
 
   // Messages with autograd info
-  MESSAGE_WITH_AUTOGRAD_REQ = 14,
-  MESSAGE_WITH_AUTOGRAD_RESP = 15,
+  FORWARD_AUTOGRAD_REQ = 14,
+  FORWARD_AUTOGRAD_RESP = 15,
+
+  // Messages to propagate gradients on the backward pass.
+  BACKWARD_AUTOGRAD_REQ = 16,
+  BACKWARD_AUTOGRAD_RESP = 17,
 
   // Other internal message types
-  SHUTDOWN = 16,
-  EXCEPTION = 17,
-  UNKNOWN = 18
+  SHUTDOWN = 50,
+  EXCEPTION = 55,
+  UNKNOWN = 60
 };
 
 // A message to be sent/received by an RpcAgent.
diff --git a/torch/csrc/distributed/rpc/py_rref.cpp b/torch/csrc/distributed/rpc/py_rref.cpp
index 54168bf..ccf03b8 100644
--- a/torch/csrc/distributed/rpc/py_rref.cpp
+++ b/torch/csrc/distributed/rpc/py_rref.cpp
@@ -79,7 +79,7 @@
   TORCH_CHECK(
       rref_->isOwner(),
       "Cannot call localValue() on a non-local reference. Call it on ",
-      RRefContext::getInstance()->getWorkerName());
+      RRefContext::getInstance().getWorkerName());
 
   if (rref_->isPyObj()) {
     const py::object& value =
@@ -109,7 +109,7 @@
   // install the dispatch table only when there are indeed RPC activities. As
   // a counter example, checkpointing a model with RRefs should not trigger
   // forks to be added as a fork or a child.
-  auto rfd = ctx->prepareChildFork(rref_);
+  auto rfd = ctx.prepareChildFork(rref_);
   return py::make_tuple(rfd.toPyTuple(), rref_->isPyObj());
 }
 
@@ -121,12 +121,12 @@
   std::shared_ptr<RRef> rref = nullptr;
   bool isPyObj = t[TYPE_IDX].cast<bool>();
   if (isPyObj) {
-    rref = ctx->getOrCreateRRef<py::object>(rfd);
+    rref = ctx.getOrCreateRRef<py::object>(rfd);
   } else {
-    rref = ctx->getOrCreateRRef<IValue>(rfd);
+    rref = ctx.getOrCreateRRef<IValue>(rfd);
   }
 
-  ctx->notifyOwnerAndParentOfFork(rfd.forkId_, rfd.parent_, rref);
+  ctx.notifyOwnerAndParentOfFork(rfd.forkId_, rfd.parent_, rref);
   return PyRRef(std::move(rref));
 }
 
diff --git a/torch/csrc/distributed/rpc/python_functions.cpp b/torch/csrc/distributed/rpc/python_functions.cpp
index 52cae25..7f9f9b2 100644
--- a/torch/csrc/distributed/rpc/python_functions.cpp
+++ b/torch/csrc/distributed/rpc/python_functions.cpp
@@ -65,7 +65,7 @@
   RRefContext::handleException(message);
   auto rr = RemoteRet::fromMessage(message);
   auto& ctx = RRefContext::getInstance();
-  ctx->delPendingUser(rr->forkId());
+  ctx.delPendingUser(rr->forkId());
 }
 
 } // namespace
@@ -92,12 +92,14 @@
       return PythonRpcHandler::getInstance().loadPythonUDFResult(
           resp.pickledPayload(), resp.tensors());
     }
-    case MessageType::MESSAGE_WITH_AUTOGRAD_RESP: {
+    case MessageType::FORWARD_AUTOGRAD_RESP: {
       auto& rpcWithAutograd = static_cast<RpcWithAutograd&>(rpc);
 
       // Attach 'recv' autograd function.
       addRecvRpcBackward(
-          rpcWithAutograd.autogradMetadata(), rpcWithAutograd.tensors());
+          rpcWithAutograd.autogradMetadata(),
+          rpcWithAutograd.tensors(),
+          rpcWithAutograd.fromWorkerId());
 
       // Handle the original RPC.
       auto wrappedMessageType = rpcWithAutograd.wrappedMessageType();
@@ -129,9 +131,10 @@
 
     // Wrap the original rpc with autograd information.
     AutogradMetadata autogradMetadata(
-        autogradContext.context_id(), autogradContainer.newAutogradMessageId());
+        autogradContext.contextId(), autogradContainer.newAutogradMessageId());
     RpcWithAutograd rpcWithAutograd(
-        MessageType::MESSAGE_WITH_AUTOGRAD_REQ,
+        agent.getWorkerInfo().id_,
+        MessageType::FORWARD_AUTOGRAD_REQ,
         autogradMetadata,
         std::move(scriptCall));
 
@@ -157,16 +160,16 @@
   auto& ctx = RRefContext::getInstance();
   // TODO: support creating RRefs on a local object.
   TORCH_INTERNAL_ASSERT(
-      ctx->getWorkerId() != dst.id_,
+      ctx.getWorkerId() != dst.id_,
       "Does not support creating RRef on self yet.");
-  auto userRRef = ctx->createUserRRef<IValue>(dst.id_);
+  auto userRRef = ctx.createUserRRef<IValue>(dst.id_);
   auto fm = agent.send(
       dst,
       ScriptRemoteCall(
           op, std::move(stack), userRRef->rrefId(), userRRef->forkId())
           .toMessage());
 
-  ctx->addPendingUser(userRRef->forkId(), userRRef);
+  ctx.addPendingUser(userRRef->forkId(), userRRef);
   fm->addCallback(finishAcceptUserRRef);
   return PyRRef(userRRef);
 }
@@ -192,9 +195,9 @@
   auto& ctx = RRefContext::getInstance();
   // TODO: support creating RRefs on a local object.
   TORCH_INTERNAL_ASSERT(
-      ctx->getWorkerId() != dst.id_,
+      ctx.getWorkerId() != dst.id_,
       "Does not support creating RRef on self yet.");
-  auto userRRef = ctx->createUserRRef<py::object>(dst.id_);
+  auto userRRef = ctx.createUserRRef<py::object>(dst.id_);
   auto fm = agent.send(
       dst,
       PythonRemoteCall(
@@ -203,7 +206,7 @@
           userRRef->forkId().toIValue())
           .toMessage());
 
-  ctx->addPendingUser(userRRef->forkId(), userRRef);
+  ctx.addPendingUser(userRRef->forkId(), userRRef);
   fm->addCallback(finishAcceptUserRRef);
   return PyRRef(userRRef);
 }
diff --git a/torch/csrc/distributed/rpc/request_callback.cpp b/torch/csrc/distributed/rpc/request_callback.cpp
index cdb9a89..9a95484 100644
--- a/torch/csrc/distributed/rpc/request_callback.cpp
+++ b/torch/csrc/distributed/rpc/request_callback.cpp
@@ -26,6 +26,8 @@
   try {
     return processMessage(request);
   } catch (std::exception& e) {
+    LOG(ERROR) << "Received error while processing request type "
+               << request.type() << ": " << e.what();
     return createException(request, e);
   }
 }
diff --git a/torch/csrc/distributed/rpc/request_callback_impl.cpp b/torch/csrc/distributed/rpc/request_callback_impl.cpp
index c222175..64f2809 100644
--- a/torch/csrc/distributed/rpc/request_callback_impl.cpp
+++ b/torch/csrc/distributed/rpc/request_callback_impl.cpp
@@ -2,13 +2,16 @@
 #include <c10/util/C++17.h>
 #include <torch/csrc/distributed/autograd/context/dist_autograd_container.h>
 #include <torch/csrc/distributed/autograd/context/dist_autograd_context.h>
+#include <torch/csrc/distributed/autograd/engine/dist_engine.h>
+#include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.h>
+#include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_resp.h>
+#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
 #include <torch/csrc/distributed/autograd/utils.h>
 #include <torch/csrc/distributed/rpc/future_message.h>
 #include <torch/csrc/distributed/rpc/python_remote_call.h>
 #include <torch/csrc/distributed/rpc/python_rpc_handler.h>
 #include <torch/csrc/distributed/rpc/python_udf_call.h>
 #include <torch/csrc/distributed/rpc/python_udf_resp.h>
-#include <torch/csrc/distributed/rpc/rpc_with_autograd.h>
 #include <torch/csrc/distributed/rpc/rref.h>
 #include <torch/csrc/distributed/rpc/rref_context.h>
 #include <torch/csrc/distributed/rpc/rref_proto.h>
@@ -61,7 +64,7 @@
       auto& src = static_cast<ScriptRemoteCall&>(rpc);
       auto& ctx = RRefContext::getInstance();
 
-      auto ownerRRef = ctx->getOrCreateOwnerRRef<IValue>(src.retRRefId());
+      auto ownerRRef = ctx.getOrCreateOwnerRRef<IValue>(src.retRRefId());
 
       // TODO: make this asynchronous
       // src is only alive within this block, use reference to avoid copy
@@ -75,7 +78,7 @@
           stack.size());
 
       ownerRRef->setValue(std::move(stack.front()));
-      ctx->addForkOfOwner(src.retRRefId(), src.retForkId());
+      ctx.addForkOfOwner(src.retRRefId(), src.retForkId());
       return c10::guts::make_unique<RemoteRet>(
           src.retRRefId(), src.retForkId());
     }
@@ -86,10 +89,10 @@
       auto forkId = ForkId::fromIValue(prc.retForkId());
       auto& ctx = RRefContext::getInstance();
 
-      auto ownerRRef = ctx->getOrCreateOwnerRRef<py::object>(rrefId);
+      auto ownerRRef = ctx.getOrCreateOwnerRRef<py::object>(rrefId);
       ownerRRef->setValue(
           PythonRpcHandler::getInstance().runPythonUDF(prc.serializedPyObj()));
-      ctx->addForkOfOwner(rrefId, forkId);
+      ctx.addForkOfOwner(rrefId, forkId);
       return c10::guts::make_unique<RemoteRet>(rrefId, forkId);
     }
     case MessageType::SCRIPT_RREF_FETCH_CALL: {
@@ -97,7 +100,7 @@
       auto& ctx = RRefContext::getInstance();
       // TODO: make this asynchronous
       std::shared_ptr<OwnerRRef<IValue>> rref =
-          ctx->getOrCreateOwnerRRef<IValue>(srf.rrefId());
+          ctx.getOrCreateOwnerRRef<IValue>(srf.rrefId());
       return c10::guts::make_unique<RRefFetchRet>(
           RRefFetchRet({rref->getValue()}));
     }
@@ -106,7 +109,7 @@
       auto& ctx = RRefContext::getInstance();
       // TODO: make this asynchronous
       std::shared_ptr<OwnerRRef<py::object>> rref =
-          ctx->getOrCreateOwnerRRef<py::object>(prf.rrefId());
+          ctx.getOrCreateOwnerRRef<py::object>(prf.rrefId());
       SerializedPyObj result =
           PythonRpcHandler::getInstance().serialize(rref->getValue());
       return c10::guts::make_unique<RRefFetchRet>(
@@ -115,28 +118,30 @@
     case MessageType::RREF_USER_DELETE: {
       auto& rud = static_cast<RRefUserDelete&>(rpc);
       auto& ctx = RRefContext::getInstance();
-      ctx->delForkOfOwner(rud.rrefId(), rud.forkId());
+      ctx.delForkOfOwner(rud.rrefId(), rud.forkId());
       return c10::guts::make_unique<RRefAck>();
     }
     case MessageType::RREF_CHILD_ACCEPT: {
       auto& rca = static_cast<RRefChildAccept&>(rpc);
       auto& ctx = RRefContext::getInstance();
-      ctx->delPendingChild(rca.forkId());
+      ctx.delPendingChild(rca.forkId());
       return c10::guts::make_unique<RRefAck>();
     }
     case MessageType::RREF_FORK_REQUEST: {
       auto& rfr = static_cast<RRefForkRequest&>(rpc);
       auto& ctx = RRefContext::getInstance();
-      ctx->addForkOfOwner(rfr.rrefId(), rfr.forkId());
+      ctx.addForkOfOwner(rfr.rrefId(), rfr.forkId());
       return c10::guts::make_unique<RRefAck>();
     }
-    case MessageType::MESSAGE_WITH_AUTOGRAD_REQ: {
+    case MessageType::FORWARD_AUTOGRAD_REQ: {
       auto& rpcWithAutograd = static_cast<RpcWithAutograd&>(rpc);
       const auto& autogradMetadata = rpcWithAutograd.autogradMetadata();
 
       // Attach 'recv' autograd function.
       DistAutogradContext* autogradContext = addRecvRpcBackward(
-          rpcWithAutograd.autogradMetadata(), rpcWithAutograd.tensors());
+          rpcWithAutograd.autogradMetadata(),
+          rpcWithAutograd.tensors(),
+          rpcWithAutograd.fromWorkerId());
 
       // Process the original RPC.
       auto wrappedMessageType = rpcWithAutograd.wrappedMessageType();
@@ -151,7 +156,8 @@
           autogradContainer.newAutogradMessageId());
 
       auto response = c10::guts::make_unique<RpcWithAutograd>(
-          MessageType::MESSAGE_WITH_AUTOGRAD_RESP,
+          rpc::RpcAgent::getDefaultRpcAgent()->getWorkerInfo().id_,
+          MessageType::FORWARD_AUTOGRAD_RESP,
           responseAutogradMetadata,
           std::move(wrappedRpcResponse));
 
@@ -162,6 +168,29 @@
       }
       return std::move(response);
     }
+    case MessageType::BACKWARD_AUTOGRAD_REQ: {
+      auto& gradientsCall = static_cast<PropagateGradientsReq&>(rpc);
+      const auto& autogradMetadata = gradientsCall.getAutogradMetadata();
+
+      // Retrieve the appropriate autograd context.
+      auto& autogradContext =
+          DistAutogradContainer::getInstance().retrieveContext(
+              autogradMetadata.autogradContextId);
+
+      // Lookup the appropriate 'send' function to enqueue.
+      std::shared_ptr<SendRpcBackward> sendFunction =
+          autogradContext.retrieveSendFunction(
+              autogradMetadata.autogradMessageId);
+
+      // Attach the gradients to the send function.
+      sendFunction->setGrads(gradientsCall.getGrads());
+
+      // Now execute the autograd graph using the "distributed engine."
+      DistEngine::getInstance().executeSendFunction(
+          autogradContext, sendFunction);
+
+      return c10::guts::make_unique<PropagateGradientsResp>();
+    }
     default: {
       TORCH_INTERNAL_ASSERT(
           false, "Request type ", messageType, " not supported.");
diff --git a/torch/csrc/distributed/rpc/rpc_agent.cpp b/torch/csrc/distributed/rpc/rpc_agent.cpp
index fb57678..97cbabb 100644
--- a/torch/csrc/distributed/rpc/rpc_agent.cpp
+++ b/torch/csrc/distributed/rpc/rpc_agent.cpp
@@ -15,6 +15,20 @@
   return workerInfo_;
 }
 
+std::shared_ptr<RpcAgent> RpcAgent::defaultRpcAgent_ = nullptr;
+
+std::shared_ptr<RpcAgent> RpcAgent::getDefaultRpcAgent() {
+  TORCH_INTERNAL_ASSERT(
+      defaultRpcAgent_, "Default rpc agent is not initialized!");
+  return defaultRpcAgent_;
+}
+
+void RpcAgent::setDefaultRpcAgent(std::shared_ptr<RpcAgent> defaultRpcAgent) {
+  TORCH_INTERNAL_ASSERT(
+      !defaultRpcAgent_, "Default rpc agent is already initialized!");
+  defaultRpcAgent_ = std::move(defaultRpcAgent);
+}
+
 } // namespace rpc
 } // namespace distributed
 } // namespace torch
diff --git a/torch/csrc/distributed/rpc/rpc_agent.h b/torch/csrc/distributed/rpc/rpc_agent.h
index 35eb2b5..0ee7f71 100644
--- a/torch/csrc/distributed/rpc/rpc_agent.h
+++ b/torch/csrc/distributed/rpc/rpc_agent.h
@@ -6,13 +6,14 @@
 #include <torch/csrc/distributed/rpc/types.h>
 
 #include <algorithm>
+#include <cctype>
 
 namespace torch {
 namespace distributed {
 namespace rpc {
 
 // A globally unique ID to identify an RpcAgent
-struct WorkerInfo {
+struct TORCH_API WorkerInfo {
   WorkerInfo(std::string name, int id)
       : WorkerInfo(std::move(name), (worker_id_t)id) {
     TORCH_CHECK(
@@ -50,7 +51,7 @@
 // will invoke the given ``RequestCallback`` to process received requests. It
 // should immediately become ready to serve request and accept response after
 // construction.
-class RpcAgent {
+class TORCH_API RpcAgent {
  public:
   // `WorkerInfo` is the globally unique identifier for this RpcAgent instance.
   // It contains a ``name_`` field and an ``id_`` field. ``name_`` is the
@@ -97,10 +98,17 @@
   // all ``RpcAgent``s reach this method and send all pending messages.
   virtual void sync() = 0;
 
+  static void setDefaultRpcAgent(std::shared_ptr<RpcAgent> defaultRpcAgent);
+
+  static std::shared_ptr<RpcAgent> getDefaultRpcAgent();
+
  protected:
   const WorkerInfo workerInfo_;
   const std::string workerName_;
   const std::unique_ptr<RequestCallback> cb_;
+
+ private:
+  static std::shared_ptr<RpcAgent> defaultRpcAgent_;
 };
 
 } // namespace rpc
diff --git a/torch/csrc/distributed/rpc/rpc_with_autograd.cpp b/torch/csrc/distributed/rpc/rpc_with_autograd.cpp
deleted file mode 100644
index 70296d8..0000000
--- a/torch/csrc/distributed/rpc/rpc_with_autograd.cpp
+++ /dev/null
@@ -1,163 +0,0 @@
-#include <torch/csrc/distributed/rpc/rpc_with_autograd.h>
-#include <c10/util/C++17.h>
-#include <torch/csrc/distributed/rpc/utils.h>
-#include <torch/csrc/utils/byte_order.h>
-
-namespace torch {
-namespace distributed {
-namespace rpc {
-
-constexpr int kAutogradMessageSize = 17;
-
-AutogradMetadata::AutogradMetadata(
-    int64_t autogradContextId_,
-    int64_t autogradMessageId_)
-    : autogradContextId(autogradContextId_),
-      autogradMessageId(autogradMessageId_) {}
-
-RpcWithAutograd::RpcWithAutograd(
-    MessageType messageType,
-    const AutogradMetadata& autogradMetadata,
-    std::unique_ptr<RpcCommandBase> wrappedRpc)
-    : messageType_(messageType), autogradMetadata_(autogradMetadata) {
-  TORCH_INTERNAL_ASSERT(wrappedRpc != nullptr, "wrappedRpc cannot be null!");
-  TORCH_INTERNAL_ASSERT(
-      messageType_ == MessageType::MESSAGE_WITH_AUTOGRAD_REQ ||
-      messageType_ == MessageType::MESSAGE_WITH_AUTOGRAD_RESP);
-  wrappedMessage_ = std::move(*wrappedRpc).toMessage();
-  tensors_ = wrappedMessage_.tensors();
-  wrappedMessageType_ = wrappedMessage_.type();
-}
-
-RpcWithAutograd::RpcWithAutograd(
-    MessageType messageType,
-    const AutogradMetadata& autogradMetadata,
-    std::unique_ptr<RpcCommandBase> wrappedRpc,
-    MessageType wrappedMessageType,
-    std::vector<torch::Tensor> tensors)
-    : messageType_(messageType),
-      autogradMetadata_(autogradMetadata),
-      wrappedRpc_(std::move(wrappedRpc)),
-      wrappedMessageType_(wrappedMessageType),
-      tensors_(std::move(tensors)) {
-  TORCH_INTERNAL_ASSERT(wrappedRpc_ != nullptr, "wrappedRpc cannot be null!");
-  TORCH_INTERNAL_ASSERT(
-      messageType_ == MessageType::MESSAGE_WITH_AUTOGRAD_REQ ||
-      messageType_ == MessageType::MESSAGE_WITH_AUTOGRAD_RESP);
-}
-
-Message RpcWithAutograd::toMessage() && {
-  auto messageId = wrappedMessage_.id();
-  auto messageType = wrappedMessage_.type();
-
-  auto payload = std::move(wrappedMessage_).movePayload();
-  TORCH_INTERNAL_ASSERT(!payload.empty());
-
-  // We append the message type (1 byte), autograd context id(8 bytes) and
-  // autograd message id(8 bytes) to the original message in network byte order
-  // (big endian).
-  size_t writableIndex = payload.size();
-
-  // Need 17 additional bytes.
-  payload.resize(payload.size() + kAutogradMessageSize);
-
-  // Add message type.
-  payload[writableIndex++] = messageType;
-
-  // Add autograd ids.
-  torch::utils::THP_encodeInt64Buffer(
-      reinterpret_cast<uint8_t*>(payload.data()) + writableIndex,
-      &autogradMetadata_.autogradContextId,
-      torch::utils::THPByteOrder::THP_BIG_ENDIAN,
-      1);
-  writableIndex += sizeof(int64_t);
-  torch::utils::THP_encodeInt64Buffer(
-      reinterpret_cast<uint8_t*>(payload.data()) + writableIndex,
-      &autogradMetadata_.autogradMessageId,
-      torch::utils::THPByteOrder::THP_BIG_ENDIAN,
-      1);
-
-  return Message(
-      std::move(payload), std::move(tensors_), messageType_, messageId);
-}
-
-std::unique_ptr<RpcWithAutograd> RpcWithAutograd::fromMessage(
-    const Message& message) {
-  MessageType originalMessageType = message.type();
-  TORCH_INTERNAL_ASSERT(
-      MessageType::MESSAGE_WITH_AUTOGRAD_REQ == originalMessageType ||
-      MessageType::MESSAGE_WITH_AUTOGRAD_RESP == originalMessageType);
-
-  std::vector<torch::Tensor> tensors = message.tensors();
-  int64_t messageId = message.id();
-  // Decode message type, autograd context id and autograd message id.
-  auto payload = message.payload();
-  TORCH_INTERNAL_ASSERT(payload.size() > kAutogradMessageSize);
-
-  int64_t autogradContextId, autogradMessageId;
-  // autograd message id.
-  size_t indexToRead = payload.size() - sizeof(int64_t);
-  TORCH_INTERNAL_ASSERT(indexToRead >= 0);
-  torch::utils::THP_decodeInt64Buffer(
-      &autogradMessageId,
-      reinterpret_cast<uint8_t*>(payload.data()) + indexToRead,
-      torch::utils::THPByteOrder::THP_BIG_ENDIAN,
-      1);
-
-  // autograd context id.
-  indexToRead -= sizeof(int64_t);
-  TORCH_INTERNAL_ASSERT(indexToRead >= 0);
-  torch::utils::THP_decodeInt64Buffer(
-      &autogradContextId,
-      reinterpret_cast<uint8_t*>(payload.data()) + indexToRead,
-      torch::utils::THPByteOrder::THP_BIG_ENDIAN,
-      1);
-
-  // message type.
-  indexToRead -= 1;
-  TORCH_INTERNAL_ASSERT(indexToRead >= 0);
-  MessageType wrappedMessageType =
-      static_cast<MessageType>(payload[indexToRead]);
-
-  // Remove the autograd information.
-  payload.resize(payload.size() - kAutogradMessageSize);
-
-  // Create new message type and build wrapped RPC.
-  Message wrappedMessage(
-      std::move(payload), std::move(tensors), wrappedMessageType, messageId);
-
-  std::unique_ptr<RpcCommandBase> wrappedRpc;
-  if (originalMessageType == MessageType::MESSAGE_WITH_AUTOGRAD_REQ) {
-    wrappedRpc = deserializeRequest(wrappedMessage);
-  } else {
-    wrappedRpc = deserializeResponse(wrappedMessage);
-  }
-
-  return c10::guts::make_unique<RpcWithAutograd>(
-      originalMessageType,
-      AutogradMetadata(autogradContextId, autogradMessageId),
-      std::move(wrappedRpc),
-      wrappedMessageType,
-      std::move(tensors));
-}
-
-std::vector<torch::Tensor>& RpcWithAutograd::tensors() {
-  return tensors_;
-}
-
-const AutogradMetadata& RpcWithAutograd::autogradMetadata() const {
-  return autogradMetadata_;
-}
-
-RpcCommandBase& RpcWithAutograd::wrappedRpc() {
-  TORCH_INTERNAL_ASSERT(wrappedRpc_ != nullptr, "wrappedRpc cannot be null!");
-  return *wrappedRpc_;
-}
-
-MessageType RpcWithAutograd::wrappedMessageType() const {
-  return wrappedMessageType_;
-}
-
-} // namespace rpc
-} // namespace distributed
-} // namespace torch
diff --git a/torch/csrc/distributed/rpc/rpc_with_autograd.h b/torch/csrc/distributed/rpc/rpc_with_autograd.h
deleted file mode 100644
index 41d3a40..0000000
--- a/torch/csrc/distributed/rpc/rpc_with_autograd.h
+++ /dev/null
@@ -1,75 +0,0 @@
-#pragma once
-
-#include <torch/csrc/distributed/rpc/rpc_command_base.h>
-
-namespace torch {
-namespace distributed {
-namespace rpc {
-
-struct TORCH_API AutogradMetadata {
-  AutogradMetadata(int64_t autogradContextId, int64_t autogradMessageId);
-
-  // autogradContextId_ is a globally unique integer that identifies a
-  // particular distributed autograd pass.
-  int64_t autogradContextId;
-  // autogradMessageId_ is a globally unique integer that identifies a pair
-  // of send/recv autograd functions.
-  int64_t autogradMessageId;
-};
-
-// Represents an RPC that includes autograd information. This class basically
-// wraps another `RpcCommandBase` object which represents the actual RPC and has
-// additional autograd information associated with that RPC.
-class TORCH_API RpcWithAutograd final : public RpcCommandBase {
- public:
-  // Used when we are sending an RPC over the wire.
-  RpcWithAutograd(
-      MessageType messageType,
-      const AutogradMetadata& autogradMetadata,
-      std::unique_ptr<RpcCommandBase> wrappedRpc);
-
-  // Used when receiving an RPC over the wire.
-  RpcWithAutograd(
-      MessageType messageType,
-      const AutogradMetadata& autogradMetadata,
-      std::unique_ptr<RpcCommandBase> wrappedRpc,
-      MessageType wrappedMessageType,
-      std::vector<torch::Tensor> tensors);
-
-  Message toMessage() && override;
-
-  static std::unique_ptr<RpcWithAutograd> fromMessage(const Message& message);
-
-  // Retrieves tensors as part of this RPC, which need to be considered for
-  // autograd computations.
-  std::vector<torch::Tensor>& tensors();
-
-  const AutogradMetadata& autogradMetadata() const;
-
-  RpcCommandBase& wrappedRpc();
-
-  // Message type of the wrapped RPC.
-  MessageType wrappedMessageType() const;
-
- private:
-  // Message type for this call.
-  MessageType messageType_;
-
-  AutogradMetadata autogradMetadata_;
-  std::unique_ptr<RpcCommandBase> wrappedRpc_;
-
-  // Serialized message representing wrappedRpc_. Used mostly as a cache to
-  // avoid serializing the request twice.
-  Message wrappedMessage_;
-
-  // message type of the wrappedMessage, this is stored separately since
-  // wrappedMessage_ is not always guaranteed to be populated.
-  MessageType wrappedMessageType_;
-
-  // Tensors part of the wrappedRpc that need to be considered for autograd.
-  std::vector<torch::Tensor> tensors_;
-};
-
-} // namespace rpc
-} // namespace distributed
-} // namespace torch
diff --git a/torch/csrc/distributed/rpc/rref.cpp b/torch/csrc/distributed/rpc/rref.cpp
index fc34f1b..34a39fa 100644
--- a/torch/csrc/distributed/rpc/rref.cpp
+++ b/torch/csrc/distributed/rpc/rref.cpp
@@ -93,7 +93,7 @@
 RRefForkData RRef::fork() const {
   auto& ctx = RRefContext::getInstance();
   return RRefForkData(
-      ownerId_, rrefId_, ctx->genGloballyUniqueId(), ctx->getWorkerId());
+      ownerId_, rrefId_, ctx.genGloballyUniqueId(), ctx.getWorkerId());
 }
 
 //////////////////////////  UserRRef  /////////////////////////////////////
@@ -115,9 +115,9 @@
 UserRRef<T>::~UserRRef() {
   // TODO: queue this in RRefContext instead of doing it here.
   auto& ctx = RRefContext::getInstance();
-  if (ctx->getWorkerId() != ownerId_) {
-    auto fm = ctx->agent()->send(
-        ctx->agent()->getWorkerInfo(ownerId_),
+  if (ctx.getWorkerId() != ownerId_) {
+    auto fm = ctx.agent()->send(
+        ctx.agent()->getWorkerInfo(ownerId_),
         RRefUserDelete(rrefId_, forkId_).toMessage());
 
     fm->addCallback(
@@ -132,7 +132,7 @@
 
 template <>
 IValue UserRRef<IValue>::toHere() {
-  auto& agent = RRefContext::getInstance()->agent();
+  auto agent = RpcAgent::getDefaultRpcAgent();
   std::shared_ptr<FutureMessage> fm = agent->send(
       agent->getWorkerInfo(ownerId_),
       ScriptRRefFetchCall(rrefId()).toMessage());
@@ -148,7 +148,7 @@
 
 template <>
 py::object UserRRef<py::object>::toHere() {
-  auto& agent = RRefContext::getInstance()->agent();
+  auto agent = RpcAgent::getDefaultRpcAgent();
   std::shared_ptr<FutureMessage> fm = agent->send(
       agent->getWorkerInfo(ownerId_),
       PythonRRefFetchCall(rrefId()).toMessage());
diff --git a/torch/csrc/distributed/rpc/rref_context.cpp b/torch/csrc/distributed/rpc/rref_context.cpp
index 5f04af9..e274fe4 100644
--- a/torch/csrc/distributed/rpc/rref_context.cpp
+++ b/torch/csrc/distributed/rpc/rref_context.cpp
@@ -7,25 +7,13 @@
 namespace distributed {
 namespace rpc {
 
-std::unique_ptr<RRefContext> RRefContext::context_ = nullptr;
-
-void RRefContext::initInstance(std::shared_ptr<RpcAgent> agent) {
-  TORCH_CHECK(!RRefContext::context_, "Can only initialize RRefContext once.");
-  TORCH_CHECK(agent, "RRefContext requires a non-null RpcAgent shared_ptr.");
-
-  RRefContext::context_ =
-      std::unique_ptr<RRefContext>(new RRefContext(std::move(agent)));
-}
-
-std::unique_ptr<RRefContext>& RRefContext::getInstance() {
-  TORCH_CHECK(
-      RRefContext::context_, "Have to initialize RRefContext before use.");
-  return RRefContext::context_;
+RRefContext& RRefContext::getInstance() {
+  static RRefContext context(RpcAgent::getDefaultRpcAgent());
+  return context;
 }
 
 void RRefContext::destroyInstance() {
-  RRefContext::getInstance()->checkRRefLeaks();
-  RRefContext::context_.reset();
+  RRefContext::getInstance().checkRRefLeaks();
 }
 
 void RRefContext::handleException(const Message& message) {
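
`RRefContext::getInstance()` above now relies on a function-local static (the C++11 "magic static", a.k.a. Meyers singleton): the context is constructed lazily, exactly once, on first use, pulling its agent from `RpcAgent::getDefaultRpcAgent()`, which is why the separate `initInstance()` step can be removed; the default agent just has to be registered before the context is first touched. A self-contained sketch of the pattern, with placeholder `Agent`/`Context` types standing in for the real classes:

    #include <cassert>
    #include <memory>
    #include <utility>

    struct Agent {};  // stand-in for RpcAgent

    class Context {   // stand-in for RRefContext
     public:
      static Context& getInstance() {
        // Constructed on the first call only; C++11 guarantees thread-safe
        // initialization of function-local statics.
        static Context context(defaultAgent());
        return context;
      }

     private:
      explicit Context(std::shared_ptr<Agent> agent) : agent_(std::move(agent)) {}

      static std::shared_ptr<Agent> defaultAgent() {
        static auto agent = std::make_shared<Agent>();
        return agent;
      }

      std::shared_ptr<Agent> agent_;
    };

    int main() {
      // Every call returns the same instance.
      assert(&Context::getInstance() == &Context::getInstance());
      return 0;
    }

Note that, with this change, destroyInstance() only checks for RRef leaks; the singleton itself lives for the duration of the process.
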
diff --git a/torch/csrc/distributed/rpc/rref_context.h b/torch/csrc/distributed/rpc/rref_context.h
index 50247e8..981e136 100644
--- a/torch/csrc/distributed/rpc/rref_context.h
+++ b/torch/csrc/distributed/rpc/rref_context.h
@@ -15,8 +15,7 @@
 // Manages RRef lifetime and keeps track of RRef forks.
 class RRefContext {
  public:
-  static void initInstance(std::shared_ptr<RpcAgent>);
-  static std::unique_ptr<RRefContext>& getInstance();
+  static RRefContext& getInstance();
   static void destroyInstance();
 
   static void handleException(const Message& message);
@@ -109,7 +108,6 @@
   // If there is any leak on any RRef, this method will throw an error.
   void checkRRefLeaks();
 
-  static std::unique_ptr<RRefContext> context_;
   static std::atomic<local_id_t> nextLocalId_;
 
   const std::shared_ptr<RpcAgent> agent_;
diff --git a/torch/csrc/distributed/rpc/utils.cpp b/torch/csrc/distributed/rpc/utils.cpp
index def4892..f843f3a 100644
--- a/torch/csrc/distributed/rpc/utils.cpp
+++ b/torch/csrc/distributed/rpc/utils.cpp
@@ -1,8 +1,9 @@
 #include <torch/csrc/distributed/rpc/utils.h>
+#include <torch/csrc/distributed/autograd/rpc_messages/propagate_gradients_req.h>
+#include <torch/csrc/distributed/autograd/rpc_messages/rpc_with_autograd.h>
 #include <torch/csrc/distributed/rpc/python_remote_call.h>
 #include <torch/csrc/distributed/rpc/python_udf_call.h>
 #include <torch/csrc/distributed/rpc/python_udf_resp.h>
-#include <torch/csrc/distributed/rpc/rpc_with_autograd.h>
 #include <torch/csrc/distributed/rpc/rref_proto.h>
 #include <torch/csrc/distributed/rpc/script_call.h>
 #include <torch/csrc/distributed/rpc/script_remote_call.h>
@@ -41,8 +42,11 @@
     case MessageType::RREF_FORK_REQUEST: {
       return RRefForkRequest::fromMessage(request);
     }
-    case MessageType::MESSAGE_WITH_AUTOGRAD_REQ: {
-      return RpcWithAutograd::fromMessage(request);
+    case MessageType::FORWARD_AUTOGRAD_REQ: {
+      return autograd::RpcWithAutograd::fromMessage(request);
+    }
+    case MessageType::BACKWARD_AUTOGRAD_REQ: {
+      return autograd::PropagateGradientsReq::fromMessage(request);
     }
     default: {
       TORCH_INTERNAL_ASSERT(
@@ -72,8 +76,11 @@
       std::string err(response.payload().begin(), response.payload().end());
       throw std::runtime_error(err);
     }
-    case MessageType::MESSAGE_WITH_AUTOGRAD_RESP: {
-      return RpcWithAutograd::fromMessage(response);
+    case MessageType::FORWARD_AUTOGRAD_RESP: {
+      return autograd::RpcWithAutograd::fromMessage(response);
+    }
+    case MessageType::BACKWARD_AUTOGRAD_RESP: {
+      return autograd::RpcWithAutograd::fromMessage(response);
     }
     default: {
       TORCH_INTERNAL_ASSERT(
diff --git a/torch/distributed/autograd/__init__.py b/torch/distributed/autograd/__init__.py
index 1a51481..0c00d34 100644
--- a/torch/distributed/autograd/__init__.py
+++ b/torch/distributed/autograd/__init__.py
@@ -8,8 +8,10 @@
     worker stores metadata associated with this context_id, which is required
     to correctly execute a distributed autograd pass.
 
-    This is only needed in the "FAST" mode for distributed autograd, where we
-    assume all RPC communication is would also be part of the backward pass.
+    This is only needed in the "FAST" mode (as described in
+    https://github.com/pytorch/pytorch/issues/23110) for distributed autograd,
+    where we assume all RPC communication would also be part of the backward
+    pass.
 
     Example::
         >> import torch.distributed.autograd as dist_autograd
@@ -25,3 +27,34 @@
 
     def __exit__(self, type, value, traceback):
         _release_context(self.autograd_context._context_id())
+
+
+def backward(roots):
+    '''
+    Kicks off the distributed backward pass using the provided roots. This
+    currently implements the "FAST" mode algorithm
+    (see https://github.com/pytorch/pytorch/issues/23110), which assumes that
+    all RPC messages sent in the same distributed autograd context across
+    workers are part of the autograd graph during the backward pass.
+
+    We use the provided roots to discover the autograd graph and compute
+    appropriate dependencies. This method blocks until the entire
+    autograd computation is done.
+
+    We accumulate the gradients on each of the nodes in the autograd context
+    corresponding to the appropriate autograd context id. The autograd context
+    id used is the current autograd context id of this node when backward() is
+    called. If there is no valid autograd context id, we throw an error.
+
+    Arguments:
+        roots: List of tensors which represent the roots of the autograd
+            computation. All the tensors should be scalars.
+
+    Example::
+        >> import torch.distributed.autograd as dist_autograd
+        >> with dist_autograd.context() as context_id:
+        >>      pred = model.forward()
+        >>      loss = loss_func(pred, target)
+        >>      dist_autograd.backward([loss])
+    '''
+    _backward(roots)
diff --git a/torch/distributed/rpc/api.py b/torch/distributed/rpc/api.py
index c80653c..34603b1 100644
--- a/torch/distributed/rpc/api.py
+++ b/torch/distributed/rpc/api.py
@@ -1,6 +1,7 @@
 from torch.distributed import invoke_rpc_builtin, invoke_rpc_python_udf
 from torch.distributed import invoke_remote_builtin, invoke_remote_python_udf
-from torch.distributed import _init_rref_context, _destroy_rref_context
+from torch.distributed import _init_rpc_agent
+from torch.distributed import _destroy_rref_context
 from torch.distributed import ProcessGroupAgent
 from torch.distributed import WorkerInfo
 from .backend_registry import is_backend_registered, init_backend
@@ -77,7 +78,6 @@
                                self_rank, group.rank()))
         # TODO: add try-except and destroy _agent in all processes if any fails.
         _agent = ProcessGroupAgent(self_name, group, num_send_recv_threads)
-        _init_rref_context(_agent)
     elif is_backend_registered(backend):
         _agent = init_backend(
             backend,
@@ -85,9 +85,9 @@
             self_name=self_name,
             init_method=init_method
         )
-        _init_rref_context(_agent)
     else:
         raise RuntimeError("Unrecognized RPC backend ", backend)
+    _init_rpc_agent(_agent)
 
 
 @_require_initialized