Better fix for graph break on boolean mask (#103052)

Previously, I mistakenly assumed that __setitem__ receives each index
as a separate argument.  But if you write x[:, b], the whole subscript
is actually passed in as a single tuple key.  Try harder: recurse into
tuple keys when looking for the boolean mask.
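
A minimal illustration (plain Python, not part of the patch) of why the
old check missed this case: the subscript in x[:, b] = y reaches
__setitem__ as one tuple key, not as separate positional arguments.

    class Probe:
        def __setitem__(self, key, value):
            print(type(key), key)

    Probe()[:, True] = 0
    # prints: <class 'tuple'> (slice(None, None, None), True)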

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/103052
Approved by: https://github.com/desertfire
diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py
index 0d9b21e..afd8856 100644
--- a/test/dynamo/test_repros.py
+++ b/test/dynamo/test_repros.py
@@ -1662,6 +1662,18 @@
         y = torch.randn(2, requires_grad=True)
         opt_fn(x, b, y)
 
+    def test_setitem_tuple_boolean_mask_diff(self):
+        def fn(x, b, y):
+            x = x.clone()
+            x[:, b] = y
+            return x
+
+        opt_fn = torch._dynamo.optimize("aot_eager")(fn)
+        x = torch.randn(8, 4, requires_grad=True)
+        b = torch.tensor([True, False, True, False])
+        y = torch.randn(2, requires_grad=True)
+        opt_fn(x, b, y)
+
     def test_torch_tensor_ops(self):
         def fn(x):
             return torch.Tensor.abs_(x)
diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py
index 429fe2a..663fff0 100644
--- a/torch/_dynamo/variables/tensor.py
+++ b/torch/_dynamo/variables/tensor.py
@@ -421,16 +421,22 @@
         elif name == "__len__":
             return self.call_method(tx, "size", [ConstantVariable(0, **options)], {})
         elif name == "__setitem__":
+            key, value = args
+
+            def has_bool_key(v):
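+                # A boolean/byte mask can be the key itself or nested inside a tuple key (e.g. x[:, b]).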
+                if isinstance(v, TensorVariable):
+                    return v.dtype in (torch.bool, torch.int8)
+                elif isinstance(v, TupleVariable):
+                    return any(has_bool_key(item) for item in v.items)
+                else:
+                    return False
+
             if (
                 not config.capture_dynamic_output_shape_ops
-                # NB: the bool tensor and the requires_grad tensor are
-                # never the same tensor!
-                and any(
-                    isinstance(a, TensorVariable)
-                    and a.dtype in (torch.bool, torch.int8)
-                    for a in args
-                )
-                and any(isinstance(a, TensorVariable) and a.requires_grad for a in args)
+                and has_bool_key(key)
+                and isinstance(value, TensorVariable)
+                and value.requires_grad
             ):
                 unimplemented(
                     "boolean masking setitem backwards requires dynamic shapes"