[JIT] handle specially mapped ops (#41503)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/41503

Fix for https://github.com/pytorch/pytorch/issues/41192

fill_ and zero_ have no functional variants obtainable by simply dropping the trailing underscore, so RemoveMutation could not rewrite them before. This PR maps them explicitly to their functional equivalents, full_like and zeros_like.
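
Concretely, the pass replaces the in-place call with its out-of-place counterpart and redirects all later uses of the mutated tensor to the new value. A rough Python-level sketch of the mapping (illustration only, not the literal graph rewrite):

    import torch

    x = torch.tensor([2, 2])
    y = torch.tensor([2, 4])

    x.zero_()    # rewritten as: x = torch.zeros_like(x)
    y.fill_(3)   # rewritten as: y = torch.full_like(y, 3)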

Test Plan: Imported from OSS

Reviewed By: jamesr66a

Differential Revision: D22629269

Pulled By: eellison

fbshipit-source-id: f1c62684dc55682c0b3845022e0461ec77d07179
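
To observe the rewrite, run the pass on a scripted function and print the graph. (Sketch: torch._C._jit_pass_remove_mutation is the internal binding that the test helper run_pass('remove_mutation', ...) resolves to.)

    import torch

    @torch.jit.script
    def foo():
        x = torch.tensor([2, 2])
        x.zero_()
        return x + x

    torch._C._jit_pass_remove_mutation(foo.graph)
    # aten::zero_ no longer appears; aten::zeros_like is used instead
    print(foo.graph)
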
diff --git a/test/jit/test_remove_mutation.py b/test/jit/test_remove_mutation.py
index 24da317..ef408e7 100644
--- a/test/jit/test_remove_mutation.py
+++ b/test/jit/test_remove_mutation.py
@@ -148,6 +148,33 @@
         self.run_pass('remove_mutation', foo.graph)
         FileCheck().check("aten::add_").run(foo.graph)
 
+    def test_special_mapped_op(self):
+        def test_successful():
+            x = torch.tensor([2, 2])
+            y = torch.tensor([2, 4])
+            x.zero_()
+            y.fill_(3)
+            return x, y
+
+        fn = torch.jit.script(test_successful)
+        graph = fn.graph
+        self.run_pass('remove_mutation', graph)
+        FileCheck().check_not("aten::zero_").check_not("aten::fill_").run(graph)
+        self.assertEqual(test_successful(), fn())
+
+        # full_like does not accept a Tensor fill value, so fill_ with a
+        # Tensor argument has no functional equivalent and must be left alone
+        def test_unsuccessful():
+            x = torch.tensor([2, 2])
+            y = torch.tensor([2, 4])
+            x.fill_(y)
+            return x + x
+
+        fn = torch.jit.script(test_unsuccessful)
+        graph = fn.graph
+        self.run_pass('remove_mutation', graph)
+        FileCheck().check('aten::fill_').run(graph)
+
     def test_lists_append(self):
         def successful_remove():
             return [i for i in range(5)]  # noqa: C416
diff --git a/torch/csrc/jit/passes/remove_mutation.cpp b/torch/csrc/jit/passes/remove_mutation.cpp
index f1ac043..19b8c05 100644
--- a/torch/csrc/jit/passes/remove_mutation.cpp
+++ b/torch/csrc/jit/passes/remove_mutation.cpp
@@ -35,10 +35,39 @@
         !(v->node()->kind() == prim::Param);
   }
 
+  bool isSpecialMappedOp(Node* n) {
+    return n->matches("aten::zero_(Tensor(a!) self) -> Tensor(a!)") ||
+        n->matches(
+            "aten::fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!)");
+  }
+
+  Node* createSpecialMappedOp(Node* n) {
+    WithInsertPoint guard(n);
+    auto inputs = n->inputs();
+    Node* new_node;
+    if (n->matches(
+            "aten::fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!)")) {
+      new_node =
+          graph_->insert(aten::full_like, {inputs.at(0), inputs.at(1)})->node();
+    } else if (n->matches("aten::zero_(Tensor(a!) self) -> Tensor(a!)")) {
+      new_node = graph_->insert(aten::zeros_like, {inputs.at(0)})->node();
+    } else {
+      TORCH_INTERNAL_ASSERT(false);
+    }
+    new_node->copyMetadata(n);
+    new_node->output()->setType(n->output()->type());
+    return new_node;
+  }
+
   bool inplaceOpVariant(Node* n) {
     if (!n->kind().is_aten()) {
       return false;
     }
+
+    if (isSpecialMappedOp(n)) {
+      return true;
+    }
+
     auto name = n->schema().name();
     bool inplace_op = name.at(name.size() - 1) == '_';
     if (!inplace_op) {
@@ -181,15 +210,21 @@
         continue;
       }
 
-      auto schema_name = node->schema().name();
-      auto new_schema = schema_name.substr(0, schema_name.size() - 1);
-      auto new_node = graph_->create(Symbol::fromQualString(new_schema), 1);
-      new_node->copyMetadata(node);
-      new_node->insertBefore(node);
-      for (Value* input : node->inputs()) {
-        new_node->addInput(input);
+      Node* new_node;
+      if (isSpecialMappedOp(node)) {
+        new_node = createSpecialMappedOp(node);
+      } else {
+        auto schema_name = node->schema().name();
+        auto new_schema = schema_name.substr(0, schema_name.size() - 1);
+        new_node = graph_->create(Symbol::fromQualString(new_schema), 1);
+        new_node->copyMetadata(node);
+        new_node->insertBefore(node);
+        for (Value* input : node->inputs()) {
+          new_node->addInput(input);
+        }
+        new_node->output()->setType(node->output()->type());
       }
-      new_node->output()->setType(node->output()->type());
+
       mutated_value->replaceAllUsesAfterNodeWith(node, new_node->output());
       node->output()->replaceAllUsesWith(new_node->output());