Fix graph break on boolean mask better (#103052)
Previously I accidentally thought setitem takes each argument as a
list. But if you write x[:, b] that actually is passed in as a tuple.
Try harder.
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/103052
Approved by: https://github.com/desertfire
diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py
index 0d9b21e..afd8856 100644
--- a/test/dynamo/test_repros.py
+++ b/test/dynamo/test_repros.py
@@ -1662,6 +1662,18 @@
y = torch.randn(2, requires_grad=True)
opt_fn(x, b, y)
+ def test_setitem_tuple_boolean_mask_diff(self):
+ def fn(x, b, y):
+ x = x.clone()
+ x[:, b] = y
+ return x
+
+ opt_fn = torch._dynamo.optimize("aot_eager")(fn)
+ x = torch.randn(8, 4, requires_grad=True)
+ b = torch.tensor([True, False, True, False])
+ y = torch.randn(2, requires_grad=True)
+ opt_fn(x, b, y)
+
def test_torch_tensor_ops(self):
def fn(x):
return torch.Tensor.abs_(x)
diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py
index 429fe2a..663fff0 100644
--- a/torch/_dynamo/variables/tensor.py
+++ b/torch/_dynamo/variables/tensor.py
@@ -421,16 +421,21 @@
elif name == "__len__":
return self.call_method(tx, "size", [ConstantVariable(0, **options)], {})
elif name == "__setitem__":
+ key, value = args
+
+ def has_bool_key(v):
+ if isinstance(v, TensorVariable):
+ return v.dtype in (torch.bool, torch.int8)
+ elif isinstance(v, TupleVariable):
+ return any(has_bool_key(item) for item in v.items)
+ else:
+ return False
+
if (
not config.capture_dynamic_output_shape_ops
- # NB: the bool tensor and the requires_grad tensor are
- # never the same tensor!
- and any(
- isinstance(a, TensorVariable)
- and a.dtype in (torch.bool, torch.int8)
- for a in args
- )
- and any(isinstance(a, TensorVariable) and a.requires_grad for a in args)
+ and has_bool_key(key)
+ and isinstance(value, TensorVariable)
+ and value.requires_grad
):
unimplemented(
"boolean masking setitem backwards requires dynamic shapes"