[JIT] handle specially mapped ops (#41503)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/41503
Fix for https://github.com/pytorch/pytorch/issues/41192
We can map fill_ and zero_ to their functional equivalents full_like and zeros_like
Test Plan: Imported from OSS
Reviewed By: jamesr66a
Differential Revision: D22629269
Pulled By: eellison
fbshipit-source-id: f1c62684dc55682c0b3845022e0461ec77d07179
diff --git a/test/jit/test_remove_mutation.py b/test/jit/test_remove_mutation.py
index 24da317..ef408e7 100644
--- a/test/jit/test_remove_mutation.py
+++ b/test/jit/test_remove_mutation.py
@@ -148,6 +148,33 @@
self.run_pass('remove_mutation', foo.graph)
FileCheck().check("aten::add_").run(foo.graph)
+ def test_special_mapped_op(self):
+ def test_successful():
+ x = torch.tensor([2, 2])
+ y = torch.tensor([2, 4])
+ x.zero_()
+ y.fill_(3)
+ return x, y
+
+ fn = torch.jit.script(test_successful)
+ graph = fn.graph
+ self.run_pass('remove_mutation', graph)
+ FileCheck().check_not("aten::zero_").check_not("aten::fill_").run(graph)
+ self.assertEqual(test_successful(), fn())
+
+ # full_like is not implemented for a tensor fill value
+
+ def test_unsuccessful():
+ x = torch.tensor([2, 2])
+ y = torch.tensor([2, 4])
+ x.fill_(y)
+ return x + x
+
+ fn = torch.jit.script(test_unsuccessful)
+ graph = fn.graph
+ self.run_pass('remove_mutation', graph)
+ FileCheck().check('aten::fill_').run(graph)
+
def test_lists_append(self):
def successful_remove():
return [i for i in range(5)] # noqa: C416
diff --git a/torch/csrc/jit/passes/remove_mutation.cpp b/torch/csrc/jit/passes/remove_mutation.cpp
index f1ac043..19b8c05 100644
--- a/torch/csrc/jit/passes/remove_mutation.cpp
+++ b/torch/csrc/jit/passes/remove_mutation.cpp
@@ -35,10 +35,39 @@
!(v->node()->kind() == prim::Param);
}
+ bool isSpecialMappedOp(Node* n) {
+ return n->matches("aten::zero_(Tensor(a!) self) -> Tensor(a!)") ||
+ n->matches(
+ "aten::fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!)");
+ }
+
+ Node* createSpecialMappedOp(Node* n) {
+ WithInsertPoint guard(n);
+ auto inputs = n->inputs();
+ Node* new_node;
+ if (n->matches(
+ "aten::fill_.Scalar(Tensor(a!) self, Scalar value) -> Tensor(a!)")) {
+ new_node =
+ graph_->insert(aten::full_like, {inputs.at(0), inputs.at(1)})->node();
+ } else if (n->matches("aten::zero_(Tensor(a!) self) -> Tensor(a!)")) {
+ new_node = graph_->insert(aten::zeros_like, {n->inputs().at(0)})->node();
+ } else {
+ TORCH_INTERNAL_ASSERT(false);
+ }
+ new_node->copyMetadata(n);
+ new_node->output()->setType(n->output()->type());
+ return new_node;
+ }
+
bool inplaceOpVariant(Node* n) {
if (!n->kind().is_aten()) {
return false;
}
+
+ if (isSpecialMappedOp(n)) {
+ return true;
+ }
+
auto name = n->schema().name();
bool inplace_op = name.at(name.size() - 1) == '_';
if (!inplace_op) {
@@ -181,15 +210,21 @@
continue;
}
- auto schema_name = node->schema().name();
- auto new_schema = schema_name.substr(0, schema_name.size() - 1);
- auto new_node = graph_->create(Symbol::fromQualString(new_schema), 1);
- new_node->copyMetadata(node);
- new_node->insertBefore(node);
- for (Value* input : node->inputs()) {
- new_node->addInput(input);
+ Node* new_node;
+ if (isSpecialMappedOp(node)) {
+ new_node = createSpecialMappedOp(node);
+ } else {
+ auto schema_name = node->schema().name();
+ auto new_schema = schema_name.substr(0, schema_name.size() - 1);
+ new_node = graph_->create(Symbol::fromQualString(new_schema), 1);
+ new_node->copyMetadata(node);
+ new_node->insertBefore(node);
+ for (Value* input : node->inputs()) {
+ new_node->addInput(input);
+ }
+ new_node->output()->setType(node->output()->type());
}
- new_node->output()->setType(node->output()->type());
+
mutated_value->replaceAllUsesAfterNodeWith(node, new_node->output());
node->output()->replaceAllUsesWith(new_node->output());