[ONNX] Redesign inplace conversion (#55033) (#56173)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56173

* Create `InplaceConverter` and `ValueTracker` to keep track of aliases of values throughout the graph. For a given value, a new alias is created whenever an inplace operation or SetAttr writes to it, or when it is carried through nested blocks owned by If/Loop nodes.
* Fix a bug where control-flow node output types are not set when the node as a whole cannot run ONNX shape inference because it contains a non-ONNX node.
* Add symbolic for `__not__` ~~and `prim_min`~~ (update: moved to a separate PR), and update `index_put` for opset 9 to support assignment without providing indices.
* Bump ORT version in CI test.
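
As an illustration, here is a minimal sketch of the kind of scripted in-place mutation the redesign is meant to handle, mirroring the `test_inplace_with_loop` case added below; the export call, output file name, and opset choice are placeholders and not part of this PR:

```python
import torch

class M(torch.nn.Module):
    def forward(self, x):
        a = torch.ones(12,)
        for i in range(10):
            # In-place update inside a loop body: the converter records a new
            # alias of `a` for each write and threads it through the Loop block.
            a.add_(torch.ones(12,))
        return a + x

# Hypothetical export call; opset version and file name are illustrative only.
torch.onnx.export(torch.jit.script(M()), (torch.randn(12,),),
                  "inplace_loop.onnx", opset_version=9)
```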

Test Plan: Imported from OSS

Reviewed By: pbelevich

Differential Revision: D27866138

Pulled By: SplitInfinity

fbshipit-source-id: ab5c9188740c50f783ceba4d54fda43c26e2fde7
diff --git a/.jenkins/caffe2/test.sh b/.jenkins/caffe2/test.sh
index c1988d2..e66b7ae 100755
--- a/.jenkins/caffe2/test.sh
+++ b/.jenkins/caffe2/test.sh
@@ -170,7 +170,7 @@
   # JIT C++ extensions require ninja, so put it into PATH.
   export PATH="/var/lib/jenkins/.local/bin:$PATH"
   if [[ "$BUILD_ENVIRONMENT" == *py3* ]]; then
-    pip install -q --user onnxruntime==1.6.0
+    pip install -q --user onnxruntime==1.7.0
   fi
   "$ROOT_DIR/scripts/onnx/test.sh"
 fi
diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py
index f33fd1b..29b8d2d 100644
--- a/test/onnx/test_pytorch_onnx_onnxruntime.py
+++ b/test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -4569,6 +4569,24 @@
         y = torch.randn(4, 5)
         self.run_test(model, (x, y))
 
+    @skipIfUnsupportedMinOpsetVersion(14)  # Need onnx::Identity of sequence in opset 14
+    def test_list_append_nested_2(self):
+        class ListModel(torch.nn.Module):
+            def forward(self, x):
+                res = []
+                res_replicate = []
+                for i in range(x.size(0)):
+                    if len(res) > 2:
+                        for j in range(x.size(1)):
+                            res.append(x[i][j])
+                        res_replicate.append(res[-1])
+                        res.append(res_replicate[-1])
+                return res, res_replicate
+
+        model = torch.jit.script(ListModel())
+        x = torch.randn(4, 4, 3, 4)
+        self.run_test(model, (x, ))
+
     @skipIfUnsupportedMinOpsetVersion(11)
     def test_list_pop(self):
         class ListModel(torch.nn.Module):
@@ -4651,6 +4669,36 @@
         y = torch.randn(4, 5)
         self.run_test(model, (x, y))
 
+    @skipIfUnsupportedMinOpsetVersion(11)
+    def test_list_set(self):
+        class ListModel(torch.nn.Module):
+            def forward(self, x, y):
+                res = []
+                for i in range(x.size(0)):
+                    res.append(x[i])
+                res[y] = x[y]
+                return res
+
+        model = torch.jit.script(ListModel())
+        x = torch.randn(12, 4)
+        y = torch.tensor(2, dtype=torch.long)
+        self.run_test(model, (x, y))
+
+    @skipIfUnsupportedMinOpsetVersion(13)
+    def test_list_idx_sum(self):
+        class ListModel(torch.nn.Module):
+            def forward(self, x, y):
+                indices = torch.arange(x.size(0))
+                res = []
+                for i in range(x.size(0)):
+                    res.append(x[i])
+                return res[torch.sum(indices[:y])]
+
+        model = torch.jit.script(ListModel())
+        x = torch.randn(12, 4)
+        y = torch.tensor(2, dtype=torch.long)
+        self.run_test(model, (x, y))
+
     @skipIfUnsupportedMinOpsetVersion(9)
     def test_tensor_factories(self):
         class TensorFactory(torch.nn.Module):
@@ -4830,6 +4878,125 @@
         self.run_test(InplaceAddModel(), (x, y), rtol=1e-2, atol=1e-2)
         self.run_test(InplaceMulModel(), (x, y), rtol=1e-2, atol=1e-2)
 
+    @skipIfUnsupportedMinOpsetVersion(9)
+    def test_inplace_with_loop(self):
+        class M(torch.nn.Module):
+            def forward(self, x):
+                a = torch.ones(12,)
+                for i in range(10):
+                    a.add_(torch.ones(12,))
+                return a + x
+
+        m = M()
+        x = torch.randn(12,)
+        self.run_test(torch.jit.script(M()), (x))
+
+    @skipIfUnsupportedMinOpsetVersion(9)
+    def test_inplace_with_loop_2(self):
+        class M(torch.nn.Module):
+            def forward(self, x):
+                _bias = torch.ones(12,)
+                a = torch.ones(12,)  # used in loop, altered.
+                a_ref = a  # not used in loop, should be altered.
+                b = x.clone()  # used in loop, not altered.
+                b_ref = b  # not used in loop, should not be altered.
+                for i in range(10):
+                    if i == 3:
+                        for j in range(5):
+                            a += _bias
+                            _bias.add_(torch.ones(12,))
+                            b = b + torch.ones(12,)
+
+                    _bias.add_(torch.ones(12,))
+                    a += _bias
+                # TODO: value for a_ref is incorrect.
+                # a_ref += torch.ones(12,)
+                b_ref += torch.ones(12,)
+                return _bias + x, a, b, b_ref
+
+        m = M()
+        x = torch.zeros(12,)
+        self.run_test(torch.jit.script(M()), (x))
+
+    @skipIfUnsupportedMinOpsetVersion(11)
+    def test_inplace_attr_with_loop(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self._bias = torch.arange(12,)
+
+            def forward(self, x):
+                self._bias = torch.arange(12,)
+                for i in range(10):
+                    if i == 3:
+                        for j in range(5):
+                            self._bias += torch.arange(12,)
+                return self._bias + x
+
+        m = M()
+        x = torch.zeros(12,)
+        self.run_test(torch.jit.script(M()), (x))
+
+    @skipIfUnsupportedMinOpsetVersion(11)
+    def test_inplace_attr_copy_with_loop(self):
+        class M(torch.nn.Module):
+            def __init__(self):
+                super().__init__()
+                self._bias = torch.arange(12,)
+
+            def forward(self, x):
+                self._bias = torch.arange(12,)
+                for i in range(10):
+                    if i == 3:
+                        for j in range(5):
+                            self._bias.copy_(torch.arange(12,))
+                        self._bias.copy_(self._bias + torch.arange(12,))
+
+                    self._bias.copy_(self._bias + torch.arange(12,))
+                return self._bias + x
+
+        m = M()
+        x = torch.zeros(12,)
+        self.run_test(torch.jit.script(M()), (x))
+
+    @skipIfUnsupportedMinOpsetVersion(14)  # Need onnx::Identity of sequence in opset 14
+    def test_inplace_sequence_with_loop(self):
+        class M(torch.nn.Module):
+            def process(self, beam_hyps: List[torch.Tensor], done: torch.Tensor, x):
+                batch_size = x.shape[0]
+                for i in range(batch_size):
+                    if done[i]:
+                        continue
+
+                    beam_idx = 0
+                    for _, token in enumerate(x[i]):
+                        beam_hyps.append(token)
+                        beam_idx += 1
+
+                        if beam_idx == 6:
+                            break
+
+                    done[i] = len(beam_hyps) > 4
+
+                return beam_hyps, done
+
+            def forward(self, x):
+                beam_hyps: List[torch.Tensor] = []
+                batch_size = x.shape[0]
+                cur_len = 0
+                max_len = x.shape[1]
+                done = torch.zeros(batch_size, dtype=torch.bool)
+                while cur_len < max_len:
+                    beam_hyps, done = self.process(beam_hyps, done, x[:, 0, :])
+                    cur_len = cur_len + 1
+
+                return beam_hyps
+
+        m = torch.jit.script(M())
+        x = torch.randn(8, 4, 3)
+        self.run_test(torch.jit.script(M()), (x))
+
+
     @disableScriptTest()  # Sort with dynamic dim not supported in ONNX
     def test_sort(self):
         class SortModel(torch.nn.Module):
@@ -7602,6 +7769,37 @@
         self.run_test(model, (x, anchors))
 
     @skipIfUnsupportedMinOpsetVersion(11)
+    def test_set_attr_5(self):
+        class MyModule(torch.nn.Module):
+            def __init__(self):
+                super(MyModule, self).__init__()
+                self.conv = torch.nn.Conv1d(10, 3, 3)
+                self.conv.bias = torch.nn.Parameter(torch.zeros(3, 10, 3))
+
+            def set_cell_anchors(self, anchors):
+                self.conv.weight = torch.arange(10)
+                for i in range(10):
+                    if i == 3:
+                        for j in range(10):
+                            w = self.conv.weight
+                            self.conv.weight = torch.arange(10) + w
+
+                    self.conv.weight = self.conv.weight + torch.arange(10)
+                    # NOTE: `is not None` and `assert` are for passing TorchScript.
+                    if self.conv.bias is not None:
+                        a = self.conv.bias
+                        assert a is not None
+                        self.conv.bias = anchors + a
+
+            def forward(self, anchors):
+                self.set_cell_anchors(anchors)
+                return self.conv.weight, self.conv.bias
+
+        model = torch.jit.script(MyModule())
+        anchors = torch.ones(3, 10, 3)
+        self.run_test(model, (anchors))
+
+    @skipIfUnsupportedMinOpsetVersion(11)
     def test_set_attr_in_loop(self):
         class MyModule(torch.nn.Module):
             def __init__(self):
@@ -7698,7 +7896,11 @@
         model = Example(10)
         random_data = torch.rand((1, 5, 30, 30))
         empty_tensor = torch.tensor([], dtype=torch.float).view(0, 0, 0, 0, 0)
-        self.run_test(model, (random_data, empty_tensor))
+        random_state = torch.rand((1, 1, 10, 30, 30))
+        self.run_test(model, (random_data, empty_tensor),
+                      input_names=['data', 'state'],
+                      dynamic_axes={'state': [0, 1, 2, 3, 4]},
+                      test_with_inputs=[(random_data, random_state)])
 
     @skipIfUnsupportedMinOpsetVersion(11)
     def test_index_put_if_3(self):
@@ -7768,6 +7970,41 @@
         empty_tensor = torch.tensor([], dtype=torch.float).view(0, 0, 0, 0, 0)
         self.run_test(model, (random_data, empty_tensor))
 
+
+    @skipIfUnsupportedMinOpsetVersion(11)
+    def test_index_put_if_5(self):
+        @torch.jit.script
+        def check_init(input_data, hidden_size, prev_state):
+            # type: (torch.Tensor, int, torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]
+            batch_size = input_data.size(0)
+            spatial_size_0 = input_data.size(2)
+            spatial_size_1 = input_data.size(3)
+            # generate empty prev_state, if None is provided
+            state_size = (2, batch_size, hidden_size, spatial_size_0, spatial_size_1)
+            state = torch.zeros(state_size, device=input_data.device)
+            state_ref = state
+            if prev_state.size(0) == 0:
+                state[:] = torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1) * 3
+                state = state + 3
+                state[:] = torch.ones(batch_size, hidden_size, spatial_size_0, spatial_size_1) * 4
+            else:
+                state = state + 2
+            return state, state_ref
+
+        class Example(torch.nn.Module):
+            def __init__(self, hidden_size):
+                super().__init__()
+                self.hidden_size = hidden_size
+
+            def forward(self, input_data, prev_state):
+                prev_state, state_ref = check_init(input_data, self.hidden_size, prev_state)
+                return prev_state, state_ref
+
+        model = Example(4)
+        random_data = torch.rand((1, 5, 4, 4))
+        empty_tensor = torch.tensor([], dtype=torch.float).view(0, 0, 0, 0, 0)
+        self.run_test(model, (random_data, empty_tensor))
+
     @skipIfUnsupportedMinOpsetVersion(11)
     def test_list_append_in_block(self):
         class ListModel(torch.nn.Module):
diff --git a/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp b/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp
index 031ad2b..18eb78f 100644
--- a/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp
+++ b/torch/csrc/jit/passes/onnx/fixup_onnx_controlflow.cpp
@@ -236,6 +236,11 @@
   // NOTE: the output order is deliberately changed to match expected order
   //       since onnx loop requires scan outputs to be the last outputs.
   auto new_outputs = ConvertSequenceDependencies(node, opset_version);
+
+  // Copy type of block output to node output.
+  for (size_t i = 0; i < node->outputs().size(); ++i) {
+    node->output(i)->setType(node->blocks().at(0)->outputs().at(i + 1)->type());
+  }
   TORCH_INTERNAL_ASSERT(output_size == new_outputs.size());
   return new_outputs;
 }
@@ -375,6 +380,11 @@
   auto* graph = if_node->owningGraph();
   FixupONNXSubblockOutputs(node);
   ONNXFixupUninitializedOutput(if_node);
+  // Copy type of block output to node output.
+  for (size_t i = 0; i < node->outputs().size(); ++i) {
+    node->output(i)->setType(node->blocks().at(0)->outputs().at(i)->type());
+  }
+
   GRAPH_DUMP("Graph after fixing controlflow: ", node->owningGraph());
   return if_node->outputs().vec();
 }
diff --git a/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp b/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp
index bc5e6cb..4a86ed3 100644
--- a/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp
+++ b/torch/csrc/jit/passes/onnx/remove_inplace_ops_for_onnx.cpp
@@ -19,17 +19,98 @@
 const std::set<c10::Symbol> inplace_ops =
     {aten::append, aten::index_put_, aten::pop, aten::insert, aten::Delete};
 
-bool IsInplaceNode(const Node* n) {
-  if (inplace_ops.find(n->kind()) != inplace_ops.end()) {
-    return true;
-  }
+// InplaceConverter defines a set of functions that together enable the
+// conversion of prim::GetAttr, prim::SetAttr, and ATen in-place operators
+// into out-of-place ONNX operators.
+struct InplaceConverter {
+  InplaceConverter(
+      std::shared_ptr<Graph> graph,
+      MutationRemover* mr,
+      Module* model = nullptr)
+      : graph_(std::move(graph)), mr_(mr), module_(model) {}
 
-  if (n->kind() == Symbol::fromQualString("onnx::Placeholder") &&
-      n->s(attr::name) == "index_put_") {
-    return true;
-  }
+  void convertMutationForONNX();
 
-  return false;
+ private:
+  void gatherAttrNameInitialValueMap(
+      Block* block,
+      std::unordered_map<std::string, Value*>& attr_name_value_map,
+      std::unordered_map<Node*, std::string>& attr_node_fullname_map);
+  void replaceAttrWithInplaceOps(
+      Block* block,
+      const std::unordered_map<std::string, Value*>& attr_name_value_map,
+      const std::unordered_map<Node*, std::string>& attr_node_fullname_map);
+
+  void convertInplaceOpsAndTrackAlias();
+  void convertInplaceOpsAndTrackAlias(Block* block);
+
+  void correctAliasReferences();
+  void correctAliasReferences(Block* block);
+  void correctAliasReferences(Node* n);
+
+  void convertGetSetAttrToInplaceOps(Block* block);
+
+  // ValueTracker provides APIs to record aliases for a single value,
+  // and to retrieve the correct alias of any given value based on the
+  // location in the graph where it is used.
+  struct ValueTracker {
+    ValueTracker() : graph_(nullptr) {}
+
+    void init(const std::shared_ptr<Graph>& graph);
+    void recordSetValue(Value* old_v, Value* new_v);
+    Value* findAliasForValueAtNode(Value* v, const Node* n) const;
+
+    std::string toString() const;
+
+   private:
+    std::shared_ptr<Graph> graph_;
+
+    // Map from aliases to root value.
+    // A single value can have multiple aliases throughout the graph,
+    // created by inplace operators and preserved through loop-carried
+    // inputs/outputs. For each such value, its first occurrence is set as
+    // the root value.
+    std::unordered_map<Value*, Value*> alias_to_value_;
+
+    // Sort the aliases based on their order in the graph.
+    // A tie can happen when two distinct aliases belong to different blocks
+    // while having the same ancestor node. The unique id is used as the tie
+    // breaker; otherwise the two aliases would be considered equal to each
+    // other. aliasComp must satisfy strict weak ordering.
+    struct aliasComp {
+      bool operator()(const Value* a, const Value* b) const {
+        auto* n_a = a->node();
+        auto* n_b = b->node();
+        if (n_a == n_b) {
+          return false;
+        }
+        auto a_b = n_a->isBefore(n_b);
+        auto b_a = n_b->isBefore(n_a);
+        if (a_b == b_a) {
+          return a->unique() < b->unique();
+        }
+        return a_b;
+      }
+    };
+    // Map from root value to aliases sorted by their order in graph.
+    std::unordered_map<Value*, std::set<Value*, aliasComp>>
+        value_to_sorted_aliases_;
+  };
+
+  std::shared_ptr<Graph> graph_;
+  MutationRemover* mr_;
+  Module* module_;
+  ValueTracker vt_;
+};
+
+bool isAncestor(const Block* a, const Block* b) {
+  while (b && b->owningNode()) {
+    if (a == b) {
+      return true;
+    }
+    b = b->owningNode()->owningBlock();
+  }
+  return a == b;
 }
 
 Node* addDummyClone(
@@ -66,332 +147,59 @@
   return newNode;
 }
 
-// Check If then/else blocks to match the number of outputs.
-// If the number of block outputs do not match, insert a dummy
-// constant of corresponding shape and type.
-Value* MatchIfBlocksOutputForValue(
-    Value* orig_data,
-    Block* outer_block,
-    Value* origOutput) {
-  if (outer_block->owningNode()->kind() == prim::Loop)
-    return outer_block->owningNode()->outputs().at(
-        outer_block->owningNode()->outputs().size() - 1);
-
-  if (outer_block->owningNode()->kind() != prim::If)
-    return nullptr;
-  size_t output_size = outer_block->outputs().size();
-
-  for (size_t i = 0; i < output_size - 1; i++) {
-    if (outer_block->outputs().at(i)->debugNameBase() ==
-        origOutput->debugNameBase()) { // Check debug names
-      outer_block->replaceOutput(i, outer_block->outputs().at(output_size - 1));
-      outer_block->eraseOutput(output_size - 1);
-      outer_block->owningNode()->eraseOutput(output_size - 1);
-      return outer_block->owningNode()->outputs().at(i);
-    }
-  }
-
-  for (Block* b : outer_block->owningNode()->blocks()) {
-    if (b->outputs().size() < output_size) {
-      auto clone_node =
-          addDummyClone(b->owningGraph(), orig_data, false, b->return_node());
-      b->registerOutput(clone_node->output());
-      b->outputs()
-          .at(b->outputs().size() - 1)
-          ->copyMetadata(
-              outer_block->outputs().at(output_size - 1)); // Copy debug names
-    }
-  }
-  return outer_block->owningNode()->outputs().at(output_size - 1);
-}
-
-// clang-format off
-// Register inplace op node inputs/outputs through the blocks.
-// Eg. The IR before updating:
-//%23 : bool = aten::eq(%22, %13)
-// = prim::If(%23)
-//  block0():
-//    %24 : int[] = prim::ListConstruct(%batch_size.1, %6, %spatial_size_0.1, %spatial_size_1.1)
-//    %25 : Tensor = aten::ones(%24, %12, %12, %12, %12)
-//    %26 : Tensor = aten::slice(%state.1, %13, %13, %10, %11)
-//    %27 : Tensor = aten::copy_(%26, %25, %9)
-//    -> ()
-//  block1():
-//    %28 : int[] = prim::ListConstruct(%batch_size.1, %6, %spatial_size_0.1, %spatial_size_1.1)
-//    %29 : Tensor = aten::randn(%28, %12, %12, %12, %12)
-//    %30: Tensor = aten::slice(%state.1, %13, %13, %10, %11)
-//    %31 : Tensor = aten::copy_(%30, %29, %9)
-//    -> ()
-// After updating:
-//%23 : bool = aten::eq(%22, %13)
-//%51 : Tensor = prim::If(%23)
-//  block0():
-//    %24 : int[] = prim::ListConstruct(%batch_size.1, %6, %spatial_size_0.1, %spatial_size_1.1)
-//    %25 : Tensor = aten::ones(%24, %12, %12, %12, %12)
-//    %26 : Tensor = aten::slice(%state.1, %13, %13, %10, %11)
-//    %32 : Tensor?[] = prim::ListConstruct()
-//    %33 : Tensor = aten::expand_as(%25, %26)
-//    %38 : int = prim::Constant[value=0]()
-//    %39 : int = aten::size(%state.1, %38)
-//    %40 : int = prim::Constant[value=4]()
-//    %41 : None = prim::Constant()
-//    %42 : None = prim::Constant()
-//    %43 : None = prim::Constant()
-//    %44 : Tensor = aten::arange(%39, %40, %41, %42, %43)
-//    %45 : int = prim::Constant[value=0]()
-//    %46 : Tensor = aten::slice(%44, %45, %13, %10, %11)
-//    %47 : int[] = prim::Constant[value=[-1]]()
-//    %48 : Tensor = aten::view(%46, %47)
-//    %49 : Tensor?[] = prim::ListConstruct(%48)
-//    %50 : Tensor = aten::index_put(%state.1, %49, %33, %9)
-//    -> (%50)
-//  block1():
-//    %28 : int[] = prim::ListConstruct(%batch_size.1, %6, %spatial_size_0.1, %spatial_size_1.1)
-//    %29 : Tensor = aten::randn(%28, %12, %12, %12, %12)
-//    %30 : Tensor = aten::slice(%state.1, %13, %13, %10, %11)
-//    %35 : Tensor?[] = prim::ListConstruct()
-//    %36 : Tensor = aten::expand_as(%29, %30)
-//    %52 : int = prim::Constant[value=0]()
-//    %53 : int = aten::size(%state.1, %52)
-//    %54 : int = prim::Constant[value=4]()
-//    %55 : None = prim::Constant()
-//    %56 : None = prim::Constant()
-//    %57 : None = prim::Constant()
-//    %58 : Tensor = aten::arange(%53, %54, %55, %56, %57)
-//    %59 : int = prim::Constant[value=0]()
-//    %60 : Tensor = aten::slice(%58, %59, %13, %10, %11)
-//    %61 : int[] = prim::Constant[value=[-1]]()
-//    %62 : Tensor = aten::view(%60, %61)
-//    %63 : Tensor?[] = prim::ListConstruct(%62)
-//    %64 : Tensor = aten::index_put(%state.1, %63, %36, %9)
-//    -> (%64)
-// clang-format on
-void RegisterInplaceNodeInIfBlocks(
-    Value* orig_data,
-    Value* new_data,
-    const std::string& output_name) {
-  auto outer_block = new_data->node()->owningBlock();
-  auto initial_block_node = outer_block->owningNode();
-
-  if ((nullptr == initial_block_node) ||
-      (initial_block_node->kind() != prim::If)) {
-    return;
-  }
-
-  auto next_block_node = initial_block_node;
-  new_data->setDebugName("_output_" + output_name);
-  outer_block->registerOutput(new_data);
-  // Block has a new output. Add the output for the prim::If node.
-  if (next_block_node->outputs().size() < outer_block->outputs().size())
-    next_block_node->addOutput()->copyMetadata(new_data);
-
-  auto next_block = next_block_node->owningBlock();
-  while (nullptr != next_block->owningNode() &&
-         next_block != orig_data->node()->owningBlock()) {
-    next_block->registerOutput(next_block_node->output(0));
-    next_block_node = next_block->owningNode();
-    // Block has a new output. Add the output for the prim::If node.
-    if (next_block_node->outputs().size() < next_block->outputs().size())
-      next_block_node->addOutput()->setType(new_data->type());
-    next_block = next_block_node->owningBlock();
-  }
-  orig_data->replaceAllUsesAfterNodeWith(
-      next_block_node,
-      next_block_node->outputs().at(next_block_node->outputs().size() - 1));
-}
-
-// clang-format off
-// Register inplace op node inputs/outputs through the blocks.
-// Eg. The IR before updating:
-//   = prim::Loop(%10, %27)
-//    block0(%stream_idx.1 : int):
-//       = prim::Loop(%9, %27)
-//        block0(%i.1 : int):
-//          %36 : Tensor = aten::select(%bias.1, %26, %stream_idx.1)
-//          %41 : Tensor = aten::copy_(%37, %40, %25)
-//          -> (%27)
-//      -> (%27)
-//  After updating:
-// %62 : Tensor = prim::Loop(%10, %27, %bias.2)
-//    block0(%stream_idx.1 : int, %bias.3 : Tensor):
-//      %61 : Tensor = prim::Loop(%9, %27, %bias.3)
-//        block0(%i.1 : int, %bias.1 : Tensor):
-//          %36 : Tensor = aten::select(%bias.1, %26, %stream_idx.1)
-//          %59 : Tensor?[] = prim::ListConstruct(%55, %58)
-//          %60 : Tensor = aten::index_put(%bias.1, %59, %45, %25)
-//          -> (%27, %60)
-//      -> (%27, %61)
-// clang-format on
-void RegisterInplaceNodeInLoopBlocks(Value* orig_data, Value* new_data) {
-  Node* inplace_node = new_data->node();
-  Block* outer_block = inplace_node->owningBlock();
-  Node* outer_block_node = outer_block->owningNode();
-
-  if (nullptr == outer_block_node) {
-    return;
-  }
-
-  if (outer_block_node->kind() != prim::Loop)
-    return;
-
-  outer_block->registerOutput(new_data);
-  std::vector<std::pair<Block*, Node*>> node_list = {
-      std::make_pair(outer_block, outer_block_node)};
-
-  outer_block_node->addOutput()->setType(new_data->type());
-  auto next_block = outer_block_node->owningBlock();
-  auto next_node = outer_block_node;
-
-  while (nullptr != next_block->owningNode() &&
-         next_block != orig_data->node()->owningBlock()) {
-    outer_block = next_block;
-    outer_block->registerOutput(
-        next_node->outputs().at(next_node->outputs().size() - 1));
-    next_node = outer_block->owningNode();
-    next_node->addOutput()->setType(new_data->type());
-    next_block = next_node->owningBlock();
-    if (next_node->kind() == prim::Loop) // Do not register input if nested in
-                                         // If block. Register in Loop blocks.
-      node_list.emplace_back(std::make_pair(outer_block, next_node));
-  }
-
-  // Register inplace node inputs through the blocks.
-  auto next_data = orig_data;
-  while (!node_list.empty()) {
-    auto cur_pair = node_list.back();
-    // Add input to current node.
-    cur_pair.second->addInput(next_data);
-    // Add input to current block.
-    auto cur_input = cur_pair.first->addInput();
-    cur_input->setType(next_data->type());
-    next_data = cur_input;
-    node_list.pop_back();
-  }
-
-  // Update inplace node inputs inside the outer most block.
-  outer_block_node = outer_block->owningNode();
-  auto prev_data =
-      outer_block_node->inputs().at(outer_block_node->inputs().size() - 1);
-  for (auto node : inplace_node->owningBlock()->nodes()) {
-    size_t idx = 0;
-    for (auto inputs_ : node->inputs()) {
-      if (inputs_ == prev_data) {
-        node->replaceInput(idx, next_data);
-        break;
-      }
-      idx++;
-    }
-  }
-
-  orig_data->replaceAllUsesAfterNodeWith(
-      next_node->outputs().at(0)->node(),
-      next_node->outputs().at(next_node->outputs().size() - 1));
-}
-
-// Register inplace op node inputs/outputs through the blocks.
-void RegisterInplaceNodeInBlocks(Value* orig_data, Value* new_data) {
-  Node* inplace_node = new_data->node();
-  Block* outer_block = inplace_node->owningBlock();
-  Node* outer_block_node = outer_block->owningNode();
-
-  if (outer_block_node == nullptr)
-    return;
-
-  // Check if the value is already registered in the block
-  bool registered = false;
-  while (IsInplaceNode(orig_data->node())) {
-    orig_data = orig_data->node()->inputs().at(0);
-  }
-  for (auto use : orig_data->uses()) {
-    if ((use.user->owningBlock() == outer_block) &&
-        (use.user->isAfter(inplace_node))) {
-      size_t idx = 0;
-      for (auto input_ : use.user->inputs()) {
-        if (input_ == orig_data) {
-          use.user->replaceInput(idx, new_data);
-          registered = true;
-        }
-        idx++;
-      }
-    }
-  }
-  if (registered)
-    return;
-
-  // Register inplace node outputs through the blocks.
-  RegisterInplaceNodeInLoopBlocks(orig_data, new_data);
-
-  RegisterInplaceNodeInIfBlocks(orig_data, new_data, orig_data->debugName());
-
-  while (nullptr != outer_block->owningNode() &&
-         outer_block != orig_data->node()->owningBlock()) {
-    MatchIfBlocksOutputForValue(orig_data, outer_block, new_data);
-    outer_block = outer_block->owningNode()->owningBlock();
-  }
-}
-
-void PrepareIndexPutForONNX(Node* node) {
+std::pair<Value*, Value*> PrepareIndexPutForONNX(Node* node) {
   TORCH_INTERNAL_ASSERT(
       node->kind() == aten::index_put || node->kind() == aten::index_put_);
   auto placeholder_node = EncapsulatePatternIntoSubblock(node).value();
-  if (node->kind() == aten::index_put_) {
-    auto orig_data = placeholder_node->input();
-    auto new_data = placeholder_node->output();
-
-    if (nullptr == placeholder_node->owningBlock()->owningNode()) {
-      orig_data->replaceAllUsesAfterNodeWith(placeholder_node, new_data);
-      return;
-    }
-    RegisterInplaceNodeInBlocks(orig_data, new_data);
-  }
+  node->destroy();
+  return std::make_pair(placeholder_node->input(0), placeholder_node->output());
 }
 
-void PrepareCopyForONNX(Node* node) {
-  if (node->kind() == aten::copy_) {
-    // aten::copy_ can be viewed as a special case of index_put, where the
-    // tensor indices input is empty.
-    // Remove aten::copy_, and replace it with index_put.
-    // 1. create an empty listConstruct node as indices input for index_put.
-    // 2. create index_put node.
+std::pair<Value*, Value*> PrepareCopyForONNX(Node* node) {
+  TORCH_INTERNAL_ASSERT(node->kind() == aten::copy_);
+  // aten::copy_ can be viewed as a special case of index_put, where the
+  // tensor indices input is empty.
+  // Remove aten::copy_, and replace it with index_put.
+  // 1. create an empty listConstruct node as indices input for index_put.
+  // 2. create index_put node.
 
-    // Tracing aten::copy_ broadcasts the rhs values.
-    // 3. Apply broadcasting for scripting.
-    WithInsertPoint guard(node);
-    auto graph = node->owningGraph();
-    auto dummy_list =
-        graph->insertNode(graph->createList(OptionalType::ofTensor(), {}))
-            ->output();
+  // Tracing aten::copy_ broadcasts the rhs values.
+  // 3. Apply broadcasting for scripting.
+  WithInsertPoint guard(node);
+  auto graph = node->owningGraph();
+  auto dummy_list =
+      graph->insertNode(graph->createList(OptionalType::ofTensor(), {}))
+          ->output();
 
-    auto expanded_value =
-        graph->insert(aten::expand_as, {node->input(1), node->input(0)});
-    expanded_value->node()->setSourceRange(node->sourceRange());
-    expanded_value->copyMetadata(node->input(1));
+  auto expanded_value =
+      graph->insert(aten::expand_as, {node->input(1), node->input(0)});
+  expanded_value->node()->setSourceRange(node->sourceRange());
+  expanded_value->copyMetadata(node->input(1));
 
-    auto index_put = graph->insert(
-        aten::index_put_,
-        {node->input(0), dummy_list, expanded_value, node->input(2)});
-    index_put->node()->setSourceRange(node->sourceRange());
-    index_put->copyMetadata(node->output());
-    node->output()->replaceAllUsesWith(index_put);
+  auto index_put = graph->insert(
+      aten::index_put_,
+      {node->input(0), dummy_list, expanded_value, node->input(2)});
+  index_put->node()->setSourceRange(node->sourceRange());
+  index_put->copyMetadata(node->output());
+  node->output()->replaceAllUsesWith(index_put);
 
-    PrepareIndexPutForONNX(index_put->node());
-  }
+  node->destroy();
+
+  return PrepareIndexPutForONNX(index_put->node());
 }
 
-void PrepareInplaceOpsInBlocksForONNX(Node* node) {
+std::pair<Value*, Value*> PrepareInplaceOpsInBlocksForONNX(Node* node) {
   if (!node->kind().is_aten())
-    return;
+    return {};
 
   auto name = node->schema().name();
   bool inplace_op = name.at(name.size() - 1) == '_';
   if (!inplace_op)
-    return;
+    return {};
 
   auto new_schema = name.substr(0, name.size() - 1);
 
   Node* input_node = node->inputs().at(0)->node();
-  if (input_node->kind() != aten::select && input_node->kind() != aten::slice)
-    return;
 
   auto graph = node->owningGraph();
   auto new_node = graph->create(Symbol::fromQualString(new_schema), 1);
@@ -401,17 +209,26 @@
   new_node->output()->setType(node->output()->type());
   new_node->insertBefore(node);
   new_node->setSourceRange(node->sourceRange());
+  node->replaceAllUsesWith(new_node);
+  node->destroy();
 
-  auto false_val_ = graph->insertConstant(false);
+  if (input_node->kind() == aten::select || input_node->kind() == aten::slice) {
+    // Cases from a[i] = x. Convert to copy_ and eventually index_put_.
+    WithInsertPoint guard(new_node);
+    auto false_val_ = graph->insertConstant(false);
 
-  auto new_copy = graph->create(aten::copy_, 1);
-  new_copy->addInput(input_node->output());
-  new_copy->addInput(new_node->output());
-  new_copy->addInput(false_val_);
-  new_copy->insertBefore(node);
-  new_copy->setSourceRange(node->sourceRange());
+    auto new_copy = graph->create(aten::copy_, 1);
+    new_copy->addInput(new_node->inputs().at(0));
+    new_copy->addInput(new_node->output());
+    new_copy->addInput(false_val_);
+    new_copy->insertAfter(new_node);
+    new_copy->setSourceRange(new_node->sourceRange());
 
-  PrepareCopyForONNX(new_copy);
+    return PrepareCopyForONNX(new_copy);
+  } else {
+    // Direct aliasing, the node is a standalone inplace op.
+    return std::make_pair(new_node->input(0), new_node->output());
+  }
 }
 
 // aten::pop is inplace. The tensor list input is updated.
@@ -419,54 +236,43 @@
 // aten::pop. Then it makes the original aten::pop operator return the updated
 // tensor list, and replaces all later uses of that tensor list with this new
 // output.
-static void PrepareListPopForONNX(Node* n) {
-  if (n->kind() == aten::pop) {
-    //   %ten : Tensor = aten::pop(%seq, %pos)
-    // Convert to
-    //   %ten : Tensor = aten::__getitem__(%seq, %pos)
-    //   %new_seq : Tensor[] = aten::pop(%seq, %pos)
-    // And replace all uses of %seq afterwards with %new_seq
-    Node* getitem_node =
-        n->owningGraph()->create(aten::__getitem__, {n->inputs()});
-    getitem_node->output()->setType(n->output()->type());
-    getitem_node->insertBefore(n);
-    n->output()->replaceAllUsesWith(getitem_node->output());
-    n->output()->setType(n->inputs().at(0)->type());
+static std::pair<Value*, Value*> PrepareListPopForONNX(Node* n) {
+  TORCH_INTERNAL_ASSERT(n->kind() == aten::pop);
+  //   %ten : Tensor = aten::pop(%seq, %pos)
+  // Convert to
+  //   %ten : Tensor = aten::__getitem__(%seq, %pos)
+  //   %new_seq : Tensor[] = aten::pop(%seq, %pos)
+  // And replace all uses of %seq afterwards with %new_seq
+  Node* getitem_node =
+      n->owningGraph()->create(aten::__getitem__, {n->inputs()});
+  getitem_node->output()->setType(n->output()->type());
+  getitem_node->insertBefore(n);
+  n->output()->replaceAllUsesWith(getitem_node->output());
+  n->output()->setType(n->inputs().at(0)->type());
 
-    if (nullptr == n->owningBlock()->owningNode()) {
-      n->inputs().at(0)->replaceAllUsesAfterNodeWith(n, n->output());
-      return;
-    }
-    RegisterInplaceNodeInBlocks(n->inputs().at(0), n->output());
-  }
+  return std::make_pair(n->input(0), n->output());
 }
 
-static void PrepareListDeleteForONNX(Node* n) {
-  if (n->kind() == aten::Delete) {
+static std::pair<Value*, Value*> PrepareListDeleteForONNX(Node* n) {
+  TORCH_INTERNAL_ASSERT(n->kind() == aten::Delete);
+  n->addOutput();
+  n->output()->setType(n->inputs().at(0)->type());
+
+  return std::make_pair(n->input(0), n->output());
+}
+
+static std::pair<Value*, Value*> PrepareListAppendAndInsertForONNX(Node* n) {
+  TORCH_INTERNAL_ASSERT(n->kind() == aten::insert || n->kind() == aten::append);
+  if (n->outputs().size() == 0) {
     n->addOutput();
     n->output()->setType(n->inputs().at(0)->type());
-
-    if (nullptr == n->owningBlock()->owningNode()) {
-      n->inputs().at(0)->replaceAllUsesAfterNodeWith(n, n->output());
-      return;
-    }
-    RegisterInplaceNodeInBlocks(n->inputs().at(0), n->output());
   }
+  return std::make_pair(n->input(0), n->output());
 }
 
-static void PrepareListAppendAndInsertForONNX(Node* n) {
-  if (n->kind() == aten::insert || n->kind() == aten::append) {
-    if (n->outputs().size() == 0) {
-      n->addOutput();
-      n->output()->setType(n->inputs().at(0)->type());
-    }
-
-    if (nullptr == n->owningBlock()->owningNode()) {
-      n->inputs().at(0)->replaceAllUsesAfterNodeWith(n, n->output());
-      return;
-    }
-    RegisterInplaceNodeInBlocks(n->inputs().at(0), n->output());
-  }
+static std::pair<Value*, Value*> PrepareListSetItemForONNX(Node* n) {
+  TORCH_INTERNAL_ASSERT(n->kind() == aten::_set_item);
+  return std::make_pair(n->input(0), n->output());
 }
 
 // Remove Mutation pass does not handle mutation on block inputs.
@@ -567,256 +373,466 @@
       name);
 }
 
-Value* registerSetAttrInBlocks(
-    const std::shared_ptr<Graph>& graph,
-    Block* block,
-    Node* cloneNode,
-    Value* origValue,
-    const std::string& output_name) {
-  RegisterInplaceNodeInLoopBlocks(origValue, cloneNode->output());
-  RegisterInplaceNodeInIfBlocks(origValue, cloneNode->output(), output_name);
-
-  Value* output = nullptr;
-  while (nullptr != block->owningNode() &&
-         block != origValue->node()->owningBlock()) {
-    output = MatchIfBlocksOutputForValue(origValue, block, cloneNode->output());
-    block = block->owningNode()->owningBlock();
-  }
-  return output;
+void InplaceConverter::ValueTracker::init(const std::shared_ptr<Graph>& graph) {
+  alias_to_value_ = {};
+  value_to_sorted_aliases_ = {};
+  graph_ = graph;
 }
 
-// clang-format off
-// The trackAndRegisterAttributesInBlocks function tracks any instances
-// of getAttr and setAttr in a sub-block and capture these nodes as inpalce
-// read/write ops. This pass captures the output of setAttr in sub-block outputs
-// so that it gets reflected into the outer block.
-// Also, the pass matched the number of If sub-block outputs
-// if a value is updated in one branch, but no updated on the other branch.
-// For example:
-//= prim::If(%12)
-//    block0():
-//      %13 : __torch__.torch.nn.modules.conv.___torch_mangle_9.Conv1d = prim::GetAttr[name="conv"](%3)
-//      %b.1 : Tensor? = prim::GetAttr[name="bias"](%13)
-//      ...
-//      %18 : __torch__.torch.nn.modules.conv.___torch_mangle_9.Conv1d = prim::GetAttr[name="conv"](%3)
-//      %19 : Tensor = aten::add(%anchors.1, %b, %6)
-//       = prim::SetAttr[name="bias"](%18, %19)
-//     -> ()
-//    block1():
-//      %20 : __torch__.torch.nn.modules.conv.___torch_mangle_9.Conv1d = prim::GetAttr[name="conv"](%3)
-//      %21 : __torch__.torch.nn.modules.conv.___torch_mangle_9.Conv1d = prim::GetAttr[name="conv"](%3)
-//      %22 : Tensor = prim::GetAttr[name="weight"](%21)
-//      %23 : Tensor = aten::slice(%22, %7, %7, %8, %6)
-//       = prim::SetAttr[name="bias"](%20, %23)
-//     -> ()
-// After the pass
-//%_output_conv.bias.3 : Tensor = prim::If(%12)
-//    block0():
-//     ...
-//      %18 : __torch__.torch.nn.modules.conv.___torch_mangle_9.Conv1d = prim::GetAttr[name="conv"](%3)
-//      %19 : Tensor = aten::add(%anchors.1, %b, %6)
-//      %_output_conv.bias.2 : Tensor = aten::clone(%19, %26)
-//     -> (%_output_conv.bias.2)
-//    block1():
-//      %20 : __torch__.torch.nn.modules.conv.___torch_mangle_9.Conv1d = prim::GetAttr[name="conv"](%3)
-//      %23 : Tensor = aten::slice(%conv.weight, %7, %7, %8, %6)
-//      %31 : None = prim::Constant()
-//      %_output_conv.bias.4 : Tensor = aten::clone(%23, %31)
-//     -> (%_output_conv.bias.4)
-// clang-format on
-void trackAndRegisterAttributesInBlocks(
-    Node* n,
-    const std::shared_ptr<Graph>& graph,
-    const Module& module_,
-    std::unordered_map<std::string, Value*>& allAttrValues,
-    std::unordered_map<std::string, Value*>& setAttrValues,
-    std::unordered_map<std::string, Value*>& nextSetAttrValues) {
-  if (n->kind() != prim::GetAttr && n->kind() != prim::SetAttr)
-    return;
+std::string InplaceConverter::ValueTracker::toString() const {
+  std::stringstream ss;
 
-  auto name = n->s(attr::name);
-  auto attrModule = module_;
-  Value* paramConst = nullptr;
-
-  auto moduleNames =
-      findSubModuleAttr(n->inputs().at(0), name, attrModule, graph);
-
-  std::string fullName("");
-  for (auto& name : moduleNames) {
-    fullName += name + '.';
+  // ss << "Current graph: " << graph_->toString() << std::endl;
+  ss << "Tracking " << value_to_sorted_aliases_.size() << " individual values."
+     << std::endl;
+  ss << "value_to_sorted_aliases_: " << std::endl;
+  size_t idx = 0;
+  for (const auto& it : value_to_sorted_aliases_) {
+    ss << "Value[" << idx << "]: " << it.first->debugName() << std::endl;
+    ss << "  Mapping to ";
+    for (auto v : it.second) {
+      ss << v->debugName() << " ";
+    }
+    ss << std::endl;
+    idx++;
   }
-  fullName += name;
 
-  if (allAttrValues.find(fullName) == allAttrValues.end() &&
-      attrModule.hasattr(name)) {
-    auto attr = attrModule.attr(name);
-    auto type = attrModule.type();
-    auto slot = *type->findAttributeSlot(name);
+  ss << "alias_to_value_: " << std::endl;
+  for (auto it : alias_to_value_) {
+    ss << "  Alias " << it.first->debugName();
+    ss << " map to " << it.second->debugName() << std::endl;
+  }
 
-    // Add model_parameters and model_buffers as model inputs. Order is
-    // preserved based on the appearance in the graph.
-    if (type->is_parameter(slot) || type->is_buffer(slot) ||
-        (attr.isObject() && !attr.toObjectRef().type()->is_module())) {
-      if (allAttrValues.find(fullName) == allAttrValues.end()) {
-        paramConst = findArgumentAsInputParam(graph, fullName, attr);
-        allAttrValues.insert({fullName, paramConst});
-      }
-    } else if (auto attrVal = tryInsertConstant(*graph, attr)) {
-      for (size_t i = 0; i < type->getAttributes().size(); i++) {
-        if (type->getAttributeName(i) == name) {
-          paramConst = *attrVal;
-          allAttrValues.insert({fullName, paramConst});
+  return ss.str();
+}
+
+void InplaceConverter::ValueTracker::recordSetValue(
+    Value* old_v,
+    Value* new_v) {
+  GRAPH_UPDATE(
+      "Calling recordSetValue with old_v: ",
+      old_v->debugName(),
+      " new_v: ",
+      new_v->debugName());
+  GRAPH_UPDATE(this->toString());
+  auto* n = new_v->node();
+  auto* owning_block = n->owningBlock();
+
+  if (alias_to_value_.find(old_v) == alias_to_value_.end()) {
+    alias_to_value_[old_v] = old_v;
+    value_to_sorted_aliases_[old_v] = {old_v};
+  }
+
+  auto root_v = alias_to_value_[old_v];
+  alias_to_value_[new_v] = root_v;
+  auto& sorted_alias = value_to_sorted_aliases_[root_v];
+  sorted_alias.insert(new_v);
+
+  // Check whether new_v is created inside an If or Loop subblock.
+  auto* owning_blocknode = owning_block->owningNode();
+  if (nullptr == owning_blocknode) {
+    return;
+  }
+  auto owning_block_nkind = owning_blocknode->kind();
+  if (owning_block_nkind != prim::Loop && owning_block_nkind != prim::If) {
+    return;
+  }
+
+  bool registered = std::any_of(
+      owning_block->outputs().begin(),
+      owning_block->outputs().end(),
+      [&sorted_alias](Value* out) {
+        return std::any_of(
+            sorted_alias.begin(), sorted_alias.end(), [&out](Value* alias) {
+              return alias == out;
+            });
+      });
+
+  bool from_outer_alias = std::any_of(
+      sorted_alias.begin(),
+      sorted_alias.end(),
+      [&owning_blocknode](Value* alias) {
+        return isAncestor(
+            alias->node()->owningBlock(), owning_blocknode->owningBlock());
+      });
+
+  // The data of this value has been changed.
+  // If this value has an alias from an outer block,
+  // the update must be reflected back outside,
+  // so it needs to be registered as a subblock output.
+  // This step can be skipped if another alias of this value has already
+  // been registered as a subblock output.
+  if (!registered && from_outer_alias) {
+    if (owning_block_nkind == prim::Loop) {
+      owning_block->registerOutput(new_v);
+      auto new_block_in = owning_block->addInput();
+      new_block_in->setType(new_v->type());
+      sorted_alias.insert(new_block_in);
+      alias_to_value_[new_block_in] = root_v;
+      owning_blocknode->addInput(root_v);
+    } else if (owning_block_nkind == prim::If) {
+      for (auto* if_sub_block : owning_blocknode->blocks()) {
+        if (owning_block == if_sub_block) {
+          if_sub_block->registerOutput(new_v);
+        } else {
+          if_sub_block->registerOutput(root_v);
         }
       }
-    } else {
-      GRAPH_DEBUG(
-          attr.type()->cast<ClassType>() ? "" : "attribute: ",
-          name,
-          " is not materializable.");
-      return;
     }
+    auto* new_blocknode_out = owning_blocknode->addOutput();
+    new_blocknode_out->setType(new_v->type());
+    recordSetValue(root_v, new_blocknode_out);
   }
 
-  if (n->kind() == prim::SetAttr) { // Handle SetAttr node
-    if (attrModule.hasattr(name)) {
-      // If inside a block, keep the output value to register in block
-      // output.
-      auto block_ = n->owningBlock();
-      Node* cloneNode =
-          addDummyClone(block_->owningGraph(), n->inputs().at(1), true, n);
-      if (block_->owningNode() &&
-          (block_->owningNode()->kind() == prim::If ||
-           block_->owningNode()->kind() == prim::Loop)) {
-        auto attrValue = (setAttrValues.find(fullName) != setAttrValues.end())
-            ? setAttrValues[fullName]
-            : allAttrValues[fullName];
-
-        auto blockOutput = registerSetAttrInBlocks(
-            graph, block_, cloneNode, attrValue, fullName);
-
-        nextSetAttrValues[fullName] = blockOutput;
-      }
-      // SetAttr writes a value to an attr. Keep this
-      // in the setAttrValues map.
-      setAttrValues[fullName] = cloneNode->output();
-    }
-  } else if (n->kind() == prim::GetAttr) { // Handle GetAttr node
-    if (setAttrValues.find(fullName) != setAttrValues.end()) {
-      // Attr has been set earlier in the graph.
-      // Read its value from setAttrValues map.
-      auto set_attr_node_input = setAttrValues[fullName];
-      // Clone SetAttr input
-      n->output()->replaceAllUsesAfterNodeWith(n, set_attr_node_input);
-    } else if (allAttrValues.find(fullName) != allAttrValues.end()) {
-      // Attr has not been set earlier in the graph. Replace it with the
-      // graph parameter if exists.
-      n->output()->replaceAllUsesWith(allAttrValues[fullName]);
-      n->removeAllInputs();
-    }
-  }
+  GRAPH_UPDATE(
+      "After recordSetValue for in: ",
+      old_v->debugName(),
+      ", out: ",
+      new_v->debugName(),
+      ". tracker status:");
+  GRAPH_UPDATE(this->toString());
 }
 
-// clang-format off
-// The registerInplaceOpAsBlockOutputs function tracks inplace op
-// (like aten::copy_ or aten::append) outputs as sub-block output.
-// Also, match the number of If sub-block outputs
-// if a value is updated in one branch, but no updated on the other branch.
-// For example:
-// = prim::If(%30)
-//    block0():
-//      ...
-//      %35 : Tensor = aten::copy_(%state_copy.1, %33, %12)
-//      -> ()
-//    block1():
-//      ...
-//      %40 : Tensor = aten::copy_(%state.1, %38, %12)
-//      -> ()
-//
-// After the pass
-//%_output_state_copy.1 : Tensor, %_output_state.1 : Tensor = prim::If(%30)
-//    block0():
-//      %_output_state.2 : Tensor = aten::clone(%state.1, %59)
-//      ...
-//      %_output_state_copy.3 : Tensor = onnx::Placeholder[name="index_put_"](%state_copy.1)...
-//      ...
-//      -> (%_output_state_copy.3, %_output_state.2)
-//    block1():
-//      %50 : None = prim::Constant()
-//      %_output_state_copy.2 : Tensor = aten::clone(%state_copy.1, %50)
-//      ...
-//      %_output_state.3 : Tensor = onnx::Placeholder[name="index_put_"](%state.1)...
-//       ...
-//      -> (%_output_state_copy.2, %_output_state.3)
-std::unordered_map<std::string, Value*> registerInplaceOpAsBlockOutputs(
-    Block* block,
-    const std::shared_ptr<Graph>& graph,
-    std::unordered_map<std::string, Value*>& allAttrValues,
-    std::unordered_map<std::string, Value*>& setAttrValues,
-    MutationRemover& mr,
-    Module* module_ = nullptr) {
-  Node* m = *block->nodes().begin();
-  WithInsertPoint guard(m);
-  std::unordered_map<std::string, Value*> nextSetAttrValues = {};
+// Based on the current alias records, pass over the graph and correct the
+// alias references for all nodes.
+void InplaceConverter::correctAliasReferences() {
+  correctAliasReferences(graph_->block());
+}
 
+void InplaceConverter::correctAliasReferences(Block* block) {
   for (auto it = block->nodes().begin(); it != block->nodes().end();) {
     Node* n = *it;
     it++; // node n can be destroyed
 
-    if (nullptr != module_ &&
-        (n->kind() == prim::GetAttr || n->kind() == prim::SetAttr)) {
-      Module moduleClone = (*module_);
-      trackAndRegisterAttributesInBlocks(
-          n,
-          graph,
-          moduleClone,
-          allAttrValues,
-          setAttrValues,
-          nextSetAttrValues);
-    } else if (n->kind() == aten::copy_) {
-      PrepareCopyForONNX(n);
-    } else if (n->kind() == aten::index_put || n->kind() == aten::index_put_) {
-      PrepareIndexPutForONNX(n);
-    } else if (mr.inplaceOpVariant(n)) {
-      PrepareInplaceOpsInBlocksForONNX(n);
-    } else if (n->kind() == aten::pop) {
-      PrepareListPopForONNX(n);
-    } else if (n->kind() == aten::insert || n->kind() == aten::append) {
-      PrepareListAppendAndInsertForONNX(n);
-    } else if (n->kind() == aten::Delete) {
-      PrepareListDeleteForONNX(n);
-    } else { // for prim::If and prim::Loop nodes with blocks.
-      for (Block* sub_block : n->blocks()) {
-        std::unordered_map<std::string, Value*> map_ =
-            registerInplaceOpAsBlockOutputs(
-                sub_block, graph, allAttrValues, setAttrValues, mr, module_);
-        std::unordered_map<std::string, Value*>::iterator mapIt;
-        for (mapIt = map_.begin(); mapIt != map_.end(); mapIt++) {
-          setAttrValues[mapIt->first] = mapIt->second;
-        }
+    correctAliasReferences(n);
+
+    auto nkind = n->kind();
+    if (nkind == prim::If || nkind == prim::Loop) {
+      for (auto* sub_block : n->blocks()) {
+        correctAliasReferences(sub_block);
       }
     }
   }
-  return nextSetAttrValues;
+  correctAliasReferences(block->return_node());
 }
 
-// Register Inplace Ops As Block Outputs
-// Inplace operations like aten::copy_ or aten::append that are inside
-// sub-blocks would require the output of the operation to be captured
-// as sub-block output, so that the inplace operation would be visible
-// to the outer block.
-// We also consider setAttr node an inplace op, and handle those
-// similarly by tracking the output as sub-block outputs.
-void RegisterInplaceOpAsBlockOutputs(
-    Module* module,
-    const std::shared_ptr<Graph>& graph,
-    MutationRemover& mr) {
-  // A map of names and values of referenced attributes, to avoid duplicates.
-  std::unordered_map<std::string, Value*> allAttrValues = {};
-  // A map of names and values of set attributes, to track mutations.
-  std::unordered_map<std::string, Value*> setAttrValues = {};
+// For every input of Node n, find the correct alias representing that input.
+void InplaceConverter::correctAliasReferences(Node* n) {
+  for (size_t i = 0; i < n->inputs().size(); ++i) {
+    auto* in = n->input(i);
+    auto* alias = vt_.findAliasForValueAtNode(in, n);
 
-  registerInplaceOpAsBlockOutputs(
-      graph->block(), graph, allAttrValues, setAttrValues, mr, module);
+    if (alias != in) {
+      n->replaceInput(i, alias);
+      GRAPH_UPDATE(
+          "Replacing ",
+          in->debugName(),
+          " with ",
+          alias->debugName(),
+          " for ",
+          *n);
+    }
+  }
+}
+
+// Find the correct alias representing Value v at Node n.
+Value* InplaceConverter::ValueTracker::findAliasForValueAtNode(
+    Value* v,
+    const Node* n) const {
+  GRAPH_UPDATE("Finding alias for value:", v->debugName(), " at node ", *n);
+  if (alias_to_value_.find(v) == alias_to_value_.end()) {
+    // This value was not affected by any inplace operator.
+    return v;
+  }
+
+  auto* root_v = alias_to_value_.find(v)->second;
+  TORCH_INTERNAL_ASSERT(
+      value_to_sorted_aliases_.find(root_v) != value_to_sorted_aliases_.end());
+  const auto& aliases = value_to_sorted_aliases_.find(root_v)->second;
+
+  // An alias is accessible only if
+  // 1. its owning block is an ancestor of the block owning n, and
+  // 2. its owning node is before n.
+  // Return the last alias that satisfies these conditions.
+  Value* found_alias = nullptr;
+  for (auto* alias : aliases) {
+    auto* alias_n = alias->node();
+    if (alias_n->isBefore(n) &&
+        isAncestor(alias_n->owningBlock(), n->owningBlock())) {
+      found_alias = alias;
+    }
+  }
+
+  TORCH_INTERNAL_ASSERT(
+      nullptr != found_alias,
+      "More details: \n",
+      n->sourceRange().str(),
+      "Input ",
+      v->debugName(),
+      " of node ",
+      *n,
+      " was modified by in-place operation, but we cannot find its updated value. ",
+      "Please report a bug to PyTorch, and/or try to avoid using in-place operators on this value.");
+
+  return found_alias;
+}
+
+// Pass over the block and gather the initial value of each attribute.
+// Also cache the full attribute name for every GetAttr/SetAttr node.
+void InplaceConverter::gatherAttrNameInitialValueMap(
+    Block* block,
+    std::unordered_map<std::string, Value*>& attr_name_value_map,
+    std::unordered_map<Node*, std::string>& attr_node_fullname_map) {
+  for (auto it = block->nodes().begin(); it != block->nodes().end();) {
+    Node* n = *it;
+    it++; // node n can be destroyed
+
+    for (auto* sub_block : n->blocks()) {
+      gatherAttrNameInitialValueMap(
+          sub_block, attr_name_value_map, attr_node_fullname_map);
+    }
+
+    if (n->kind() != prim::GetAttr && n->kind() != prim::SetAttr)
+      continue;
+
+    auto name = n->s(attr::name);
+    auto attrModule = *module_;
+    Value* paramConst = nullptr;
+
+    auto moduleNames =
+        findSubModuleAttr(n->inputs().at(0), name, attrModule, graph_);
+
+    std::string fullName("");
+    for (auto& name : moduleNames) {
+      fullName += name + '.';
+    }
+    fullName += name;
+
+    attr_node_fullname_map.insert({n, fullName});
+
+    if (attr_name_value_map.find(fullName) == attr_name_value_map.end() &&
+        attrModule.hasattr(name)) {
+      auto attr = attrModule.attr(name);
+      auto type = attrModule.type();
+      auto slot = *type->findAttributeSlot(name);
+
+      // Add model_parameters and model_buffers as model inputs. Order is
+      // preserved based on the appearance in the graph.
+      WithInsertPoint guard(graph_->nodes().front());
+      if (type->is_parameter(slot) || type->is_buffer(slot) ||
+          (attr.isObject() && !attr.toObjectRef().type()->is_module())) {
+        paramConst = findArgumentAsInputParam(graph_, fullName, attr);
+        attr_name_value_map.insert({fullName, paramConst});
+      } else if (auto attrVal = tryInsertConstant(*graph_, attr)) {
+        // TODO: Extend support for attribute of type List[Tensor] etc.
+        for (size_t i = 0; i < type->getAttributes().size(); i++) {
+          if (type->getAttributeName(i) == name) {
+            paramConst = *attrVal;
+            attr_name_value_map.insert({fullName, paramConst});
+          }
+        }
+      } else {
+        // The attribute is a custom class object rather than a primitive
+        // type, Tensor, or List/Tuple/Dict of Tensors.
+        GRAPH_DEBUG(
+            attr.type()->cast<ClassType>() ? "" : "attribute: ",
+            name,
+            " is not materializable.");
+      }
+    }
+
+    // Create a dummy initial value if one does not exist for this
+    // attribute.
+    if (attr_name_value_map.find(fullName) == attr_name_value_map.end()) {
+      auto* noneNode = graph_->create(prim::Constant);
+      noneNode->output()->setType(NoneType::get());
+      noneNode->insertBefore(graph_->nodes().front());
+      attr_name_value_map.insert({fullName, noneNode->output()});
+    }
+  }
+}
+
+// Replace prim::GetAttr and prim::SetAttr with ATen inplace operators.
+// Example graph:
+// clang-format off
+//  Before graph(%x.1 : Float(12, strides=[1], requires_grad=0, device=cpu)):
+//    %1 : __torch__.___torch_mangle_1.M = prim::CreateObject()
+//    ...
+//    %10 : Tensor = aten::arange(%6, %7, %7, %7, %7)
+//     = prim::SetAttr[name="_bias"](%1, %10)
+//     = prim::Loop(%5, %8)
+//      block0(%i.1 : int):
+//        %12 : bool = aten::eq(%i.1, %4)
+//         = prim::If(%12)
+//          block0():
+//             = prim::Loop(%3, %8)
+//              block0(%j : int):
+//                %14 : Tensor = prim::GetAttr[name="_bias"](%1)
+//                %15 : Tensor = aten::add_(%14, %2, %9)
+//                 = prim::SetAttr[name="_bias"](%1, %15)
+//                -> (%8)
+//            -> ()
+//          block1():
+//            %16 : Tensor = aten::arange(%6, %7, %7, %7, %7)
+//             = prim::SetAttr[name="_bias"](%1, %16)
+//            -> ()
+//        -> (%8)
+//    %17 : Tensor = prim::GetAttr[name="_bias"](%1)
+//    %18 : Tensor = aten::add(%17, %x.1, %9)
+//    return (%18)
+//
+//  After graph(%x.1 : Float(12, strides=[1], requires_grad=0, device=cpu)):
+//    %19 : Float(2, strides=[1], requires_grad=0, device=cpu) = prim::Constant[value= 1  1 [ CPUFloatType{2} ]]()
+//    %1 : __torch__.___torch_mangle_1.M = prim::CreateObject()
+//    ...
+//    %10 : Tensor = aten::arange(%6, %7, %7, %7, %7)
+//    %26 : bool = prim::Constant[value=0]()
+//    %27 : Tensor?[] = prim::ListConstruct()
+//    %28 : Tensor = aten::index_put_(%19, %27, %10, %26)
+//     = prim::Loop(%5, %8)
+//      block0(%i.1 : int):
+//        %12 : bool = aten::eq(%i.1, %4)
+//         = prim::If(%12)
+//          block0():
+//             = prim::Loop(%3, %8)
+//              block0(%j : int):
+//                %15 : Tensor = aten::add_(%19, %2, %9)
+//                %23 : bool = prim::Constant[value=0]()
+//                %24 : Tensor?[] = prim::ListConstruct()
+//                %25 : Tensor = aten::index_put_(%19, %24, %15, %23)
+//                -> (%8)
+//            -> ()
+//          block1():
+//            %16 : Tensor = aten::arange(%6, %7, %7, %7, %7)
+//            %20 : bool = prim::Constant[value=0]()
+//            %21 : Tensor?[] = prim::ListConstruct()
+//            %22 : Tensor = aten::index_put_(%19, %21, %16, %20)
+//            -> ()
+//        -> (%8)
+//    %18 : Tensor = aten::add(%19, %x.1, %9)
+//    return (%18)
+// clang-format on
+void InplaceConverter::replaceAttrWithInplaceOps(
+    Block* block,
+    const std::unordered_map<std::string, Value*>& attr_name_value_map,
+    const std::unordered_map<Node*, std::string>& attr_node_fullname_map) {
+  for (const auto& pair : attr_node_fullname_map) {
+    auto* n = pair.first;
+    auto fullName = pair.second;
+    auto find_init_val = attr_name_value_map.find(fullName);
+    TORCH_INTERNAL_ASSERT(find_init_val != attr_name_value_map.end());
+
+    TORCH_INTERNAL_ASSERT(
+        n->kind() == prim::GetAttr || n->kind() == prim::SetAttr);
+    if (n->kind() == prim::SetAttr) {
+      // Convert SetAttr to inplace op.
+      // Directly convert to index_put_ instead of copy_, since we know expand
+      // is not required for value.
+      WithInsertPoint guard(n);
+      auto false_val_ = graph_->insertConstant(false);
+      auto dummy_list =
+          graph_->insertNode(graph_->createList(OptionalType::ofTensor(), {}))
+              ->output();
+
+      auto* index_put_node = graph_->create(aten::index_put_, 1);
+      index_put_node->addInput(find_init_val->second);
+      index_put_node->addInput(dummy_list);
+      index_put_node->addInput(n->input(1));
+      index_put_node->addInput(false_val_);
+      index_put_node->setSourceRange(n->sourceRange());
+      index_put_node->insertBefore(n);
+    } else if (n->kind() == prim::GetAttr) {
+      // Replace uses of GetAttr with the first seen alias (usually the initial
+      // value) of that attribute. The correct alias at the point of this node
+      // will be discovered and assigned in a later pass.
+      n->output()->replaceAllUsesWith(find_init_val->second);
+    }
+
+    n->destroy();
+  }
+}
+
+void InplaceConverter::convertGetSetAttrToInplaceOps(Block* block) {
+  std::unordered_map<std::string, Value*> attr_name_value_map = {};
+  std::unordered_map<Node*, std::string> attr_node_fullname_map = {};
+  // First pass over the graph to gather all attribute names and their initial
+  // values. Create dummy initial values for attributes if necessary. By the end
+  // of this pass, these dummy initial values should have zero uses, and can be
+  // safely removed. Any remaining use implies the model reads an uninitialized
+  // value, which is an error.
+  gatherAttrNameInitialValueMap(
+      block, attr_name_value_map, attr_node_fullname_map);
+  GRAPH_UPDATE(
+      "Graph after gatherAttrNameInitialValueMap: ", graph_->toString());
+
+  // Second pass over graph,
+  // replace GetAttr with first seen alias (usually initial value),
+  // and replace SetAttr with inplace op, updating new value onto first seen
+  // alias.
+  replaceAttrWithInplaceOps(block, attr_name_value_map, attr_node_fullname_map);
+}
+
+// Convert inplace ops to their out-of-place versions, and record the
+// associated new alias in ValueTracker.
+void InplaceConverter::convertInplaceOpsAndTrackAlias(Block* block) {
+  for (auto it = block->nodes().begin(); it != block->nodes().end();) {
+    Node* n = *it;
+    it++; // node n can be destroyed
+
+    auto nkind = n->kind();
+    if (nkind == prim::If || nkind == prim::Loop) {
+      for (Block* sub_block : n->blocks()) {
+        convertInplaceOpsAndTrackAlias(sub_block);
+      }
+    } else {
+      Value *orig_data = nullptr, *new_out = nullptr;
+      if (nkind == aten::copy_) {
+        std::tie(orig_data, new_out) = PrepareCopyForONNX(n);
+      } else if (nkind == aten::index_put || nkind == aten::index_put_) {
+        std::tie(orig_data, new_out) = PrepareIndexPutForONNX(n);
+        if (nkind == aten::index_put) {
+          // Special case: index_put is not inplace, so no alias is recorded.
+          continue;
+        }
+      } else if (nkind == aten::insert || nkind == aten::append) {
+        std::tie(orig_data, new_out) = PrepareListAppendAndInsertForONNX(n);
+      } else if (mr_->inplaceOpVariant(n)) {
+        std::tie(orig_data, new_out) = PrepareInplaceOpsInBlocksForONNX(n);
+      } else if (nkind == aten::pop) {
+        std::tie(orig_data, new_out) = PrepareListPopForONNX(n);
+      } else if (nkind == aten::Delete) {
+        std::tie(orig_data, new_out) = PrepareListDeleteForONNX(n);
+      } else if (nkind == aten::_set_item) {
+        std::tie(orig_data, new_out) = PrepareListSetItemForONNX(n);
+      } else {
+        // Not an inplace op.
+        continue;
+      }
+
+      if (nullptr != orig_data && nullptr != new_out) {
+        vt_.recordSetValue(orig_data, new_out);
+      }
+    }
+  }
+}
+
+void InplaceConverter::convertInplaceOpsAndTrackAlias() {
+  convertInplaceOpsAndTrackAlias(graph_->block());
+  GRAPH_UPDATE(
+      "Graph after convertInplaceOpsAndTrackAlias: ", graph_->toString());
+  GRAPH_UPDATE(vt_.toString());
+}
+
+void InplaceConverter::convertMutationForONNX() {
+  // First pass to convert all prim::GetAttr and prim::SetAttr to ATen inplace
+  // operators.
+  convertGetSetAttrToInplaceOps(graph_->block());
+  GRAPH_UPDATE(
+      "Graph after convertGetSetAttrToInplaceOps: ", graph_->toString());
+  vt_.init(graph_);
+  // Second pass to convert all inplace operators to their out-of-place
+  // versions, and record the associated new alias in ValueTracker.
+  convertInplaceOpsAndTrackAlias();
+  // Third pass to check and correct alias references for all nodes.
+  correctAliasReferences();
 }
 
 } // namespace
@@ -829,7 +845,8 @@
   PrepareForRemoveMutations(mr, graph->block());
   RemoveTensorMutation(graph);
   RemoveListMutation(graph);
-  RegisterInplaceOpAsBlockOutputs(model, graph, mr);
+  InplaceConverter ic(graph, &mr, model);
+  ic.convertMutationForONNX();
 }
 
 } // namespace jit
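For context, below is a minimal sketch (the module name `M` and the attribute `_bias` are illustrative, not taken from this PR) of the kind of scripted module that produces the GetAttr/SetAttr-inside-Loop/If pattern shown in the before/after graphs above. Scripting it and printing the graph should surface the `prim::GetAttr` and `prim::SetAttr` nodes that `InplaceConverter` rewrites.

```python
import torch

# Minimal sketch, assuming a plain tensor attribute; scripting this module
# yields prim::GetAttr / prim::SetAttr nodes inside Loop/If blocks, which
# convertGetSetAttrToInplaceOps rewrites into aten::index_put_ updates on the
# attribute's initial value.
class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self._bias = torch.arange(2.0)

    def forward(self, x):
        for i in range(5):
            if i == 3:
                # GetAttr followed by SetAttr inside nested blocks.
                self._bias = self._bias + 1.0
            else:
                # SetAttr only.
                self._bias = torch.arange(2.0)
        return self._bias + x

scripted = torch.jit.script(M())
print(scripted.graph)  # shows the GetAttr/SetAttr nodes before conversion
```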
diff --git a/torch/onnx/symbolic_opset11.py b/torch/onnx/symbolic_opset11.py
index 050b952..39282bdf 100644
--- a/torch/onnx/symbolic_opset11.py
+++ b/torch/onnx/symbolic_opset11.py
@@ -256,6 +256,9 @@
         from torch.onnx.symbolic_opset9 import __getitem_ as getitem
         return getitem(g, self, i)
 
+def _set_item(g, tensor_list, i, v):
+    tensor_list = g.op("SequenceErase", tensor_list, i)
+    return g.op("SequenceInsert", tensor_list, v, i)
 
 def append(g, self, tensor):
     return g.op("SequenceInsert", self, tensor)
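The new `_set_item` symbolic lowers list item assignment (`res[i] = v`) to a `SequenceErase` at index `i` followed by a `SequenceInsert` at the same index. A plain-Python reference of that erase-then-insert equivalence (purely illustrative; `set_item_reference` is not part of the exporter):

```python
# Plain-Python sketch of the SequenceErase + SequenceInsert lowering used by
# _set_item: erase the element at i, then insert the new value at i.
def set_item_reference(seq, i, v):
    seq = seq[:i] + seq[i + 1:]       # SequenceErase(seq, i)
    return seq[:i] + [v] + seq[i:]    # SequenceInsert(seq, v, i)

assert set_item_reference([10, 20, 30], 1, 99) == [10, 99, 30]
```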
diff --git a/torch/onnx/symbolic_opset9.py b/torch/onnx/symbolic_opset9.py
index 08f7309..c41e484 100644
--- a/torch/onnx/symbolic_opset9.py
+++ b/torch/onnx/symbolic_opset9.py
@@ -1115,6 +1115,10 @@
     return wrap_with_not
 
 
+def __not_(g, self):
+    return g.op("Not", self)
+
+
 def eq(g, self, other):
     return g.op("Equal", self, other)
 
@@ -1463,10 +1467,21 @@
 
 
 def index_put(g, self, indices_list_value, values, accumulate):
-    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
+    if sym_help._is_packed_list(indices_list_value):
         indices_list = sym_help._unpack_list(indices_list_value)
+    else:
+        indices_list = [indices_list_value]
+    if sym_help._operator_export_type == torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK:
         args = [self] + indices_list + [values, accumulate]
         return g.op("ATen", *args, operator_s='index_put')
+
+    accumulate = sym_help._parse_arg(accumulate, 'b')
+
+    if len(indices_list) == 0:
+        if accumulate:
+            return add(g, self, values)
+        else:
+            return values
     else:
         sym_help._onnx_opset_unsupported('index_put', 9, 11)
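The opset 9 `index_put` update handles the pattern emitted by the attribute conversion above: an empty indices list means the whole tensor is written, so the op reduces to returning `values` when `accumulate` is false, or an elementwise add when it is true. A small sketch of that equivalence (the tensors are illustrative, not from the test suite):

```python
import torch

# Sketch of the semantics the new empty-indices branch relies on:
# writing with no indices touches every element of `self`.
self_tensor = torch.zeros(3)
values = torch.tensor([1.0, 2.0, 3.0])

overwrite = values                  # accumulate=False -> result is `values`
accumulated = self_tensor + values  # accumulate=True  -> Add(self, values)

assert torch.equal(accumulated, torch.tensor([1.0, 2.0, 3.0]))
```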