index_put device error checking (#113729)

Fix for https://github.com/pytorch/pytorch/issues/101371

`aten.index_put.default` and `aten.index_put_.default` previously went through the
generic multi-device fake-tensor handlers, so a device mismatch between `self` and
`values` that raises a RuntimeError in eager mode passed silently under
FakeTensorMode. Register a dedicated impl for both ops that raises on the mismatch
while still allowing a 0-dim scalar `values` tensor on a different device, matching
eager behavior.
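
For reference, a minimal eager-mode sketch of the behavior the fake impl mirrors
(illustration only, not part of this patch; assumes a CUDA build):

```python
# Illustration only: eager index_put rejects a self/values device mismatch,
# but accepts a 0-dim scalar values tensor on a different device.
import torch

x = torch.randn(2, 2, 3, device="cuda")
idx = torch.tensor([1, 1], device="cuda")

try:
    # CPU values written into a CUDA tensor -> RuntimeError in eager mode
    torch.ops.aten.index_put(x, [idx], torch.randn(2, 2, 3))
except RuntimeError as e:
    print("eager error:", e)

# scalar (0-dim) values are allowed regardless of device
torch.ops.aten.index_put(x, [idx], torch.tensor(5.0))
```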

Pull Request resolved: https://github.com/pytorch/pytorch/pull/113729
Approved by: https://github.com/bdhirsh
diff --git a/test/test_fake_tensor.py b/test/test_fake_tensor.py
index 65b6989..0c87878 100644
--- a/test/test_fake_tensor.py
+++ b/test/test_fake_tensor.py
@@ -267,6 +267,23 @@
         self.assertTrue(isinstance(fake_x.grad, FakeTensor))
 
     @unittest.skipIf(not RUN_CUDA, "requires cuda")
+    def test_index_put_error(self):
+        mode = FakeTensorMode()
+        for context in [contextlib.nullcontext, lambda: mode]:
+            with context():
+                y = torch.randn(2, 2, 3)
+                x = torch.randn(2, 2, 3).to('cuda')
+                with self.assertRaises(RuntimeError):
+                    x[[1, 1]] = y
+
+                with self.assertRaises(RuntimeError):
+                    torch.ops.aten.index_put(x, torch.tensor([1, 1], device="cuda"), y)
+
+                # no error: a 0-dim scalar values tensor is accepted even on a mismatching device
+                torch.ops.aten.index_put(x, torch.tensor([1, 1], device="cuda"), torch.tensor(5.))
+                torch.ops.aten.index_put_(x, torch.tensor([1, 1], device="cuda"), torch.tensor(5.))
+
+    @unittest.skipIf(not RUN_CUDA, "requires cuda")
     def test_like_constructor(self):
         with FakeTensorMode():
             x = torch.rand([4, 4])
diff --git a/torch/_subclasses/fake_tensor.py b/torch/_subclasses/fake_tensor.py
index d4dbdd8..a8765f5 100644
--- a/torch/_subclasses/fake_tensor.py
+++ b/torch/_subclasses/fake_tensor.py
@@ -702,7 +702,6 @@
 
 
 # takes in multiple-devices, dont default to default device handling
-@register_op_impl(aten.index_put.default)
 @register_op_impl(aten._unsafe_index_put.default)
 @register_op_impl(aten.copy.default)
 @register_op_impl(aten.copy_.default)
@@ -712,7 +711,6 @@
 
 
 # same with multi_device_op_default, but return the input
-@register_op_impl(aten.index_put_.default)
 @register_op_impl(aten.copy.out)
 @register_op_impl(aten.slice_scatter.out)
 def multi_device_op_out(fake_mode, func, *args, **kwargs):
@@ -726,6 +724,27 @@
     return new_kwargs["input"]
 
 
+@register_op_impl(aten.index_put.default)
+@register_op_impl(aten.index_put_.default)
+def index_put_impl(fake_mode, func, *args, **kwargs):
+    _, new_kwargs = normalize_function(
+        func, args=args, kwargs=kwargs, normalize_to_only_use_kwargs=True
+    )
+
+    values = new_kwargs["values"]
+    self_device = new_kwargs["input"].fake_device
+    torch._check(
+        self_device == values.fake_device or (values.ndim == 0 and values.numel() == 1),
+        lambda: f"Mismatching {func} device between self ({self_device}) and values ({values.device})",
+    )
+
+    out = run_and_return_new_tensor_of_input_device(fake_mode, func, args, kwargs)
+    if func is aten.index_put_.default:
+        return new_kwargs["input"]
+    else:
+        return out
+
+
 @register_op_impl(lambda fn: fn in _device_not_kwarg_ops)
 def nyi(fake_mode, func, *args, **kwargs):
     assert func not in _device_not_kwarg_ops, f"NYI: {func}"
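
Usage sketch of the new impl (illustration only; assumes a CUDA build, mirroring the
test's RUN_CUDA skip): the same mismatch now surfaces under FakeTensorMode instead of
passing silently.

```python
# Illustration only: the device mismatch is reported for fake tensors too.
import torch
from torch._subclasses.fake_tensor import FakeTensorMode

with FakeTensorMode():
    x = torch.randn(2, 2, 3, device="cuda")   # FakeTensor with fake_device=cuda
    idx = torch.tensor([1, 1], device="cuda")
    try:
        torch.ops.aten.index_put(x, [idx], torch.randn(2, 2, 3))  # CPU values
    except RuntimeError as e:
        print("fake tensor error:", e)

    # a 0-dim scalar values tensor is still accepted
    torch.ops.aten.index_put_(x, [idx], torch.tensor(5.0))
```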