Add `GRAPH_UPDATE` for `x.size()` in Peephole Optimize (#34865)
Summary:
Fixes https://github.com/pytorch/pytorch/issues/31820 by logging a `GRAPH_UPDATE` when the peephole pass folds `x.size()` to a constant, and by switching the existing `GRAPH_UPDATE` calls from streaming the full node (`*node`) to printing only its one-line header via `getHeader(node)`.
Pull Request resolved: https://github.com/pytorch/pytorch/pull/34865
Reviewed By: jamesr66a
Differential Revision: D20772078
Pulled By: suo
fbshipit-source-id: cddf870e23983cc42da898edf3f98897353b2abe
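
To see the new log line, `GRAPH_UPDATE`-level logging for this pass has to be enabled via `PYTORCH_JIT_LOG_LEVEL=">peephole"` (a single `>` raises the log level for `peephole.cpp` from `GRAPH_DUMP` to `GRAPH_UPDATE`). Below is a minimal repro sketch, not part of this PR: it assumes a libtorch build with access to the JIT internals, and that `parseIR` preserves the concrete `Float(2, 3)` type on the input, so the pass can fold `aten::size`.

```cpp
// Illustrative sketch only. Run with PYTORCH_JIT_LOG_LEVEL=">peephole"
// so GRAPH_UPDATE-level messages from peephole.cpp are printed.
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
#include <torch/csrc/jit/passes/peephole.h>

int main() {
  auto graph = std::make_shared<torch::jit::Graph>();
  // %x carries complete sizes, so ptt->sizes().concrete_sizes() is set.
  torch::jit::parseIR(
      R"IR(
graph(%x : Float(2, 3)):
  %sizes : int[] = aten::size(%x)
  return (%sizes)
)IR",
      graph.get());

  // With concrete sizes on %x, the pass folds aten::size into a
  // prim::Constant and emits the new "(x.size()) is replaced with" line.
  torch::jit::PeepholeOptimize(graph);
  graph->dump();
}
```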
diff --git a/torch/csrc/jit/passes/peephole.cpp b/torch/csrc/jit/passes/peephole.cpp
index fbbb35b..ae9b308 100644
--- a/torch/csrc/jit/passes/peephole.cpp
+++ b/torch/csrc/jit/passes/peephole.cpp
@@ -176,7 +176,7 @@
"aten::_grad_sum_to_size(Tensor(a) self, int[]? size) -> Tensor(a)")) {
if (node->input(1)->mustBeNone()) {
GRAPH_UPDATE(
- *node,
+ getHeader(node),
" (x._grad_sum_to_size(x, None) == x) is replaced with ",
node->input(0)->debugName());
node->output()->replaceAllUsesWith(node->input(0));
@@ -188,7 +188,7 @@
"aten::_grad_sum_to_size(Tensor(a) self, int[]? size) -> Tensor(a)") &&
u.user->input(1)->type()->isSubtypeOf(ListType::ofInts())) {
GRAPH_UPDATE(
- *node,
+ getHeader(node),
" (x._grad_sum_to_size(y)._grad_sum_to_size(z) == x._grad_sum_to_size(z)) is replaced with ",
node->inputs().at(0)->debugName());
u.user->replaceInput(0, node->inputs().at(0));
@@ -208,7 +208,7 @@
if (expanded_sizes.has_value() && input_type_sizes &&
expanded_sizes->vec() == *input_type_sizes) {
GRAPH_UPDATE(
- *node,
+ getHeader(node),
" (x.expand(x.size()) == x) is replaced with ",
node->namedInput(attr::self)->debugName());
node->output()->replaceAllUsesWith(node->namedInput(attr::self));
@@ -220,7 +220,7 @@
Node* input_node = node->input()->node();
if (input_node->matches("aten::t(Tensor self) -> Tensor")) {
GRAPH_UPDATE(
- *node,
+ getHeader(node),
" (x.t().t() == x) is replaced with ",
input_node->input()->debugName());
node->output()->replaceAllUsesWith(input_node->input());
@@ -234,7 +234,7 @@
if (mustBeEqual(self_type->scalarType(), other_type->scalarType()) &&
mustBeEqual(self_type->device(), other_type->device())) {
GRAPH_UPDATE(
- *node,
+ getHeader(node),
" (x.type_as(y) == x) is replaced with ",
node->input(0)->debugName());
node->output()->replaceAllUsesWith(node->input(0));
@@ -248,7 +248,7 @@
Node* input_node = node->input()->node();
if (input_node->kind() == prim::NumToTensor) {
GRAPH_UPDATE(
- *node,
+ getHeader(node),
" (x.NumToTensor().TensorToNum() == x.NumToTensor()) is replaced with ",
node->input()->debugName());
node->output()->replaceAllUsesWith(input_node->input());
@@ -257,6 +257,10 @@
} else if (node->matches("aten::size(Tensor self) -> int[]")) {
if (auto ptt = node->input()->type()->cast<TensorType>()) {
if (auto sizes = ptt->sizes().concrete_sizes()) {
+ GRAPH_UPDATE(
+ getHeader(node),
+ " (x.size()) is replaced with ",
+ node->input()->debugName());
WithInsertPoint guard(node);
IValue ival(sizes);
auto const_sizes_val = node->owningGraph()->insertConstant(ival);
@@ -303,7 +307,8 @@
WithInsertPoint guard(node);
auto output = node->owningGraph()->insertConstant(
node->kind() == aten::__isnot__);
- GRAPH_UPDATE("Folding ", *node, " to ", output->debugName());
+ GRAPH_UPDATE(
+ "Folding ", getHeader(node), " to ", output->debugName());
node->output()->replaceAllUsesWith(output);
changed_ = true;
}
@@ -316,7 +321,7 @@
if (input->mustNotBeNone()) {
GRAPH_UPDATE(
"Unwrapping ",
- *node,
+ getHeader(node),
" as ",
node->input(),
" can't be optional");
@@ -330,7 +335,9 @@
auto output_type = unshapedType(node->output()->type());
if (input_type->isSubtypeOf(output_type)) {
GRAPH_UPDATE(
- "Removing ", *node, " as input type subtypes output type");
+ "Removing ",
+ getHeader(node),
+ " as input type subtypes output type");
node->output()->replaceAllUsesWith(node->input());
}
} else if (node->matches("prim::dtype(Tensor a) -> int")) {
@@ -341,7 +348,7 @@
static_cast<int64_t>(*ptt->scalarType()));
GRAPH_UPDATE(
"Replacing ",
- *node,
+ getHeader(node),
" with a type constant ",
output->debugName());
node->output()->replaceAllUsesWith(output);
@@ -354,7 +361,7 @@
auto output = node->owningGraph()->insertConstant(*ptt->device());
GRAPH_UPDATE(
"Replacing ",
- *node,
+ getHeader(node),
" with a device constant ",
output->debugName());
node->output()->replaceAllUsesWith(output);
@@ -368,7 +375,7 @@
node->owningGraph()->insertConstant(static_cast<int64_t>(*dim));
GRAPH_UPDATE(
"Replacing ",
- *node,
+ getHeader(node),
" with a \"dim\" constant ",
output->debugName());
node->output()->replaceAllUsesWith(output);
@@ -382,7 +389,7 @@
node->owningGraph()->insertConstant((*ptt->device()).is_cuda());
GRAPH_UPDATE(
"Replacing ",
- *node,
+ getHeader(node),
" with a is_cuda constant ",
output->debugName());
node->output()->replaceAllUsesWith(output);
@@ -430,7 +437,7 @@
return;
}
GRAPH_UPDATE(
- *node,
+ getHeader(node),
" (x + 0 == x - 0 == x) is replaced with ",
node->input(0)->debugName());
node->output()->replaceAllUsesWith(node->input(0));
@@ -449,7 +456,7 @@
return;
}
GRAPH_UPDATE(
- *node,
+ getHeader(node),
" (x * 1 == x / 1 == x) is replaced with ",
node->input(0)->debugName());
node->output()->replaceAllUsesWith(node->input(0));
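
A note on the `*node` → `getHeader(node)` change: streaming a `Node` prints it recursively, including any nested blocks, so a single `GRAPH_UPDATE` message could span many lines, while `getHeader(node)` (from `torch/csrc/jit/jit_log.h`) prints only the node's one-line header. With the sketch above, the new log line should look roughly like the following (illustrative output; the exact prefix, line number, and value names will vary):

```
[UPDATE peephole.cpp:260] %sizes : int[] = aten::size(%x) (x.size()) is replaced with x
```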