Graph break on differentiable boolean mask setitem (#102843)
Fixes https://github.com/pytorch/pytorch/issues/102841
Signed-off-by: Edward Z. Yang <ezyang@meta.com>
Pull Request resolved: https://github.com/pytorch/pytorch/pull/102843
Approved by: https://github.com/voznesenskym
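For context (a sketch, not part of the diff): `x[b] = y` with a boolean mask lowers to `index_put_`, and in the backward pass the gradient flowing to `y` is gathered at the masked positions, so its shape depends on how many entries of `b` are `True`, a value known only at runtime. Capturing that in a static graph requires dynamic output shape support, so when `capture_dynamic_output_shape_ops` is off Dynamo now graph-breaks instead of miscompiling. A minimal eager-mode illustration of the shape dependence, assuming a recent PyTorch:

```python
import torch

x = torch.randn(4, requires_grad=True)
b = torch.tensor([True, False, True, False])
y = torch.randn(2, requires_grad=True)

out = x.clone()
out[b] = y            # boolean-mask setitem (index_put_ under the hood)
out.sum().backward()

# The gradient reaching y is gathered at the True positions of b, so its
# shape equals b.sum() -- a data-dependent quantity. Here that is (2,).
print(y.grad.shape)   # torch.Size([2])
```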
diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py
index 46e3858..61e9d49 100644
--- a/test/dynamo/test_repros.py
+++ b/test/dynamo/test_repros.py
@@ -1650,6 +1650,18 @@
         res = fn(y)
         self.assertTrue(same(ref, res))
 
+    def test_setitem_boolean_mask_diff(self):
+        def fn(x, b, y):
+            x = x.clone()
+            x[b] = y
+            return x
+
+        opt_fn = torch._dynamo.optimize("aot_eager")(fn)
+        x = torch.randn(4, requires_grad=True)
+        b = torch.tensor([True, False, True, False])
+        y = torch.randn(2, requires_grad=True)
+        opt_fn(x, b, y)
+
     def test_torch_tensor_ops(self):
         def fn(x):
             return torch.Tensor.abs_(x)
diff --git a/torch/_dynamo/variables/tensor.py b/torch/_dynamo/variables/tensor.py
index 65c730d..76e4990 100644
--- a/torch/_dynamo/variables/tensor.py
+++ b/torch/_dynamo/variables/tensor.py
@@ -424,6 +424,20 @@
         elif name == "__len__":
             return self.call_method(tx, "size", [ConstantVariable(0, **options)], {})
         elif name == "__setitem__":
+            if (
+                not config.capture_dynamic_output_shape_ops
+                # NB: the bool tensor and the requires_grad tensor are
+                # never the same tensor!
+                and any(
+                    isinstance(a, TensorVariable)
+                    and a.dtype in (torch.bool, torch.int8)
+                    for a in args
+                )
+                and any(isinstance(a, TensorVariable) and a.requires_grad for a in args)
+            ):
+                unimplemented(
+                    "boolean masking setitem backwards requires dynamic shapes"
+                )
             tx.output.guards.update(options["guards"])
             tx.output.create_proxy(
                 "call_function",