Add device-mismatch error checking for index_put under FakeTensorMode (#113729)
Fix for https://github.com/pytorch/pytorch/issues/101371
Pull Request resolved: https://github.com/pytorch/pytorch/pull/113729
Approved by: https://github.com/bdhirsh
diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py
index 65b6989..0c87878 100644
--- a/test/test_fake_tensor.py
+++ b/test/test_fake_tensor.py
@@ -267,6 +267,25 @@
self.assertTrue(isinstance(fake_x.grad, FakeTensor))
@unittest.skipIf(not RUN_CUDA, "requires cuda")
+ def test_index_put_error(self):
+ mode = FakeTensorMode()
+ for context in [contextlib.nullcontext, lambda: mode]:
+ with context():
+ y = torch.randn(2, 2, 3)
+ x = torch.randn(2, 2, 3).to('cuda')
+ with self.assertRaises(RuntimeError):
+ x[[1, 1]] = y
+
+ with self.assertRaises(RuntimeError):
+ torch.ops.aten.index_put(x, torch.tensor([1, 1], device="cuda"), y)
+
+            # no error: a 0-dim (scalar) value is exempt from the device-mismatch check
+ torch.ops.aten.index_put(x, torch.tensor([1, 1], device="cuda"), torch.tensor(5.))
+ torch.ops.aten.index_put_(x, torch.tensor([1, 1], device="cuda"), torch.tensor(5.))
+
+
+
+ @unittest.skipIf(not RUN_CUDA, "requires cuda")
def test_like_constructor(self):
with FakeTensorMode():
x = torch.rand([4, 4])
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index d4dbdd8..a8765f5 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -702,7 +702,6 @@
# takes in multiple-devices, dont default to default device handling
-@register_op_impl(aten.index_put.default)
@register_op_impl(aten._unsafe_index_put.default)
@register_op_impl(aten.copy.default)
@register_op_impl(aten.copy_.default)
@@ -712,7 +711,6 @@
# same with multi_device_op_default, but return the input
-@register_op_impl(aten.index_put_.default)
@register_op_impl(aten.copy.out)
@register_op_impl(aten.slice_scatter.out)
def multi_device_op_out(fake_mode, func, *args, **kwargs):
@@ -726,6 +724,27 @@
return new_kwargs["input"]
+@register_op_impl(aten.index_put.default)
+@register_op_impl(aten.index_put_.default)
+def index_put_impl(fake_mode, func, *args, **kwargs):
+ _, new_kwargs = normalize_function(
+ func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+ )
+
+ values = new_kwargs["values"]
+ self_device = new_kwargs["input"].fake_device
+ torch._check(
+ self_device == values.fake_device or (values.ndim == 0 and values.numel() == 1),
+ lambda: f"Mismatching {func} device between self ({self_device}) and values ({values.device})",
+ )
+
+ out = run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs)
+ if func is aten.index_put_.default:
+ return new_kwargs["input"]
+ else:
+ return out
+
+
@register_op_impl(lambda fn: fn in _device_not_kwarg_ops)
def nyi(fake_mode, func, *args, **kwargs):
assert func not in _device_not_kwarg_ops, f"NYI: {func}"