Never CSE aten.empty in the partitioner (#134703)
aten.empty is almost always fusible into its consumer, so we never CSE
it. This fixes a bug that looks like the following:
```py
@torch.library.custom_op("_reinplacing::sin_cos", mutates_args={"out_sin", "out_cos"})
def sin_cos(x: torch.Tensor, out_sin: torch.Tensor, out_cos: torch.Tensor) -> None:
out_sin.copy_(x.sin())
out_cos.copy_(x.cos())
@torch.compile
def f(x):
out0 = torch.empty_like(x)
out1 = torch.empty_like(x)
sin_cos(x, out0, out1)
return x.clone(), out0, out1
x = torch.randn(3, requires_grad=True)
f(x)
```
- CSE would de-duplicate the two `empty` nodes (see the toy sketch below)
- reinplacing would then add an extra clone, because the custom op can't write both
  outputs into the same buffer
- the clone lowers into a new buffer plus a `copy_` kernel
- that `copy_` kernel is unnecessary, because `empty` is special: all reinplacing
  needed was a fresh buffer, and it doesn't matter what values it holds.
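To make the first step concrete, here is a standalone toy sketch of how a naive CSE keyed on `(target, args, kwargs)` merges the two `empty_like` calls. It uses a plain `torch.fx` trace and a hand-rolled dedup loop rather than the partitioner's actual CSE pass, and it drops the custom op to keep the trace simple, so treat it as an illustration of the mechanism only:
```py
import torch
from torch.fx import symbolic_trace

def f(x):
    out0 = torch.empty_like(x)
    out1 = torch.empty_like(x)
    return out0, out1

gm = symbolic_trace(f)

# Naive CSE: merge call_function nodes with identical (target, args, kwargs).
seen = {}
for node in gm.graph.nodes:
    if node.op != "call_function":
        continue
    key = (node.target, node.args, tuple(sorted(node.kwargs.items())))
    if key in seen:
        node.replace_all_uses_with(seen[key])  # out1 now points at out0's node
    else:
        seen[key] = node

gm.graph.eliminate_dead_code()
gm.recompile()
print(gm.graph)  # only one empty_like remains; both outputs share that buffer
```
Once the two nodes are merged, anything downstream (here, the reinplacing pass) sees a single buffer standing in for both mutated arguments, which is what forces the extra clone described above.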
We could attempt to fix this on the reinplacing side, but handling it as a
partitioner heuristic seemed better, and the reinplacing fix is trickier
(we'd need to identify that the op never reads from the empty node).
Test Plan:
- new test (without this PR the numel count was 27; with it the count is 21,
  so this PR helped).
Pull Request resolved: https://github.com/pytorch/pytorch/pull/134703
Approved by: https://github.com/yf225
ghstack dependencies: #134466, #134490, #134491
diff --git a/test/inductor/test_perf.py b/test/inductor/test_perf.py
index 14694e5..302e246 100644
--- a/test/inductor/test_perf.py
+++ b/test/inductor/test_perf.py
@@ -906,6 +906,26 @@
self.assertExpectedInline(count_numel_train(f, x), """9""")
@requires_cuda
+ def test_inplace_custom_op_training_two_mutated_inputs(self):
+ @torch.library.custom_op(
+ "_reinplacing::sin_cos", mutates_args={"out_sin", "out_cos"}
+ )
+ def sin_cos(
+ x: torch.Tensor, out_sin: torch.Tensor, out_cos: torch.Tensor
+ ) -> None:
+ out_sin.copy_(x.sin())
+ out_cos.copy_(x.cos())
+
+ def f(x):
+ out0 = torch.empty_like(x)
+ out1 = torch.empty_like(x)
+ sin_cos(x, out0, out1)
+ return x.clone(), out0, out1
+
+ x = T(3, grad=True)
+ self.assertExpectedInline(count_numel(f, x), """21""")
+
+ @requires_cuda
def test_inplace_custom_op_training(self):
@torch.library.custom_op("_reinplacing::sin", mutates_args={"result"})
def sin(x: torch.Tensor, result: torch.Tensor) -> None:
diff --git a/torch/_functorch/compile_utils.py b/torch/_functorch/compile_utils.py
index 77aeeeb..ac7e0d5 100644
--- a/torch/_functorch/compile_utils.py
+++ b/torch/_functorch/compile_utils.py
@@ -58,6 +58,10 @@
or n.op == "output"
or n.op == "get_attr"
or get_aten_target(n) in rand_ops
+ # aten.empty is non-deterministic, so don't CSE it.
+ # Also, aten.empty is almost always fusible into its consumer,
+ # so it's not worth CSEing.
+ or get_aten_target(n) is aten.empty
):
new_node = new_graph.node_copy(n, lambda x: env[x])
env[n] = new_node