Never CSE aten.empty in the partitioner (#134703)

aten.empty is almost always fusible into its consumer, so we never CSE
it. This fixes a bug that looks like the following:

```py
import torch

@torch.library.custom_op("_reinplacing::sin_cos", mutates_args={"out_sin", "out_cos"})
def sin_cos(x: torch.Tensor, out_sin: torch.Tensor, out_cos: torch.Tensor) -> None:
    out_sin.copy_(x.sin())
    out_cos.copy_(x.cos())

@torch.compile
def f(x):
    out0 = torch.empty_like(x)
    out1 = torch.empty_like(x)
    sin_cos(x, out0, out1)
    return x.clone(), out0, out1

x = torch.randn(3, requires_grad=True)
f(x)
```

- CSE would de-duplicate the two empty nodes
- reinplacing would then add an additional clone (because the op can't write to
  both tensors if they alias the same buffer)
- the clone lowers into a new buffer + a copy_ kernel
- the copy_ kernel is unnecessary because "empty" is special: all reinplacing
  needed was an additional buffer, it doesn't matter what the values are
  (see the sketch after this list).
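
Concretely, the chain above looks roughly like this at the aten level. This is a hand-written sketch assuming the sin_cos op from the snippet above is registered; the node spellings are illustrative and the real FX/Inductor graphs are more involved:

```py
import torch

def traced_before_cse(x):
    # Two distinct output buffers; reinplacing needs no extra copies.
    buf0 = torch.ops.aten.empty_like.default(x)
    buf1 = torch.ops.aten.empty_like.default(x)
    torch.ops._reinplacing.sin_cos.default(x, buf0, buf1)
    return x.clone(), buf0, buf1

def traced_after_cse_and_reinplacing(x):
    # CSE merged the two empty nodes into one, so reinplacing has to insert
    # a clone (a new buffer + a copy_ kernel) to keep the two writes separate.
    # The copy_ is wasted work: it copies uninitialized values.
    buf = torch.ops.aten.empty_like.default(x)
    buf_clone = torch.ops.aten.clone.default(buf)
    torch.ops._reinplacing.sin_cos.default(x, buf, buf_clone)
    return x.clone(), buf, buf_clone
```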

We could attempt to fix this on the reinplacing side, but a partitioner
heuristic seemed better, and the reinplacing fix is a bit trickier (we'd
need to identify that the op never reads from the empty node).

Test Plan:
- new test (the count_numel result for this pattern was 27 before this change
  and is 21 after, so this PR helped).

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134703
Approved by: https://github.com/yf225
ghstack dependencies: #134466, #134490, #134491
diff --git a/test/inductor/test_perf.py b/test/inductor/test_perf.py
index 14694e5..302e246 100644
--- a/test/inductor/test_perf.py
+++ b/test/inductor/test_perf.py
@@ -906,6 +906,26 @@
         self.assertExpectedInline(count_numel_train(f, x), """9""")
 
     @requires_cuda
+    def test_inplace_custom_op_training_two_mutated_inputs(self):
+        @torch.library.custom_op(
+            "_reinplacing::sin_cos", mutates_args={"out_sin", "out_cos"}
+        )
+        def sin_cos(
+            x: torch.Tensor, out_sin: torch.Tensor, out_cos: torch.Tensor
+        ) -> None:
+            out_sin.copy_(x.sin())
+            out_cos.copy_(x.cos())
+
+        def f(x):
+            out0 = torch.empty_like(x)
+            out1 = torch.empty_like(x)
+            sin_cos(x, out0, out1)
+            return x.clone(), out0, out1
+
+        x = T(3, grad=True)
+        self.assertExpectedInline(count_numel(f, x), """21""")
+
+    @requires_cuda
     def test_inplace_custom_op_training(self):
         @torch.library.custom_op("_reinplacing::sin", mutates_args={"result"})
         def sin(x: torch.Tensor, result: torch.Tensor) -> None:
diff --git a/torch/_functorch/compile_utils.py b/torch/_functorch/compile_utils.py
index 77aeeeb..ac7e0d5 100644
--- a/torch/_functorch/compile_utils.py
+++ b/torch/_functorch/compile_utils.py
@@ -58,6 +58,10 @@
             or n.op == "output"
             or n.op == "get_attr"
             or get_aten_target(n) in rand_ops
+            # aten.empty is non-deterministic, so don't CSE it.
+            # Also, aten.empty is almost always fusible into its consumer,
+            # so it's not worth CSEing.
+            or get_aten_target(n) is aten.empty
         ):
             new_node = new_graph.node_copy(n, lambda x: env[x])
             env[n] = new_node