Graph break on differentiable boolean mask setitem (#102843)

Fixes https://github.com/pytorch/pytorch/issues/102841

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/102843
Approved by: https://github.com/voznesenskym
diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py
index 46e3858..61e9d49 100644
--- a/test/dynamo/test_repros.py
+++ b/test/dynamo/test_repros.py
@@ -1650,6 +1650,18 @@
         res = fn(y)
         self.assertTrue(same(ref, res))
 
+    def test_setitem_boolean_mask_diff(self):
+        def fn(x, b, y):
+            x = x.clone()
+            x[b] = y
+            return x
+
+        opt_fn = torch._dynamo.optimize("aot_eager")(fn)
+        x = torch.randn(4, requires_grad=True)
+        b = torch.tensor([True, False, True, False])
+        y = torch.randn(2, requires_grad=True)
+        opt_fn(x, b, y)
+
     def test_torch_tensor_ops(self):
         def fn(x):
             return torch.Tensor.abs_(x)
diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py
index 65c730d..76e4990 100644
--- a/torch/_dynamo/variables/tensor.py
+++ b/torch/_dynamo/variables/tensor.py
@@ -424,6 +424,20 @@
         elif name == "__len__":
             return self.call_method(tx, "size", [ConstantVariable(0, **options)], {})
         elif name == "__setitem__":
+            if (
+                not config.capture_dynamic_output_shape_ops
+                # NB: the bool tensor and the requires_grad tensor are
+                # never the same tensor!
+                and any(
+                    isinstance(a, TensorVariable)
+                    and a.dtype in (torch.bool, torch.int8)
+                    for a in args
+                )
+                and any(isinstance(a, TensorVariable) and a.requires_grad for a in args)
+            ):
+                unimplemented(
+                    "boolean masking setitem backwards requires dynamic shapes"
+                )
             tx.output.guards.update(options["guards"])
             tx.output.create_proxy(
                 "call_function",