[ONNX] Improve index_put symbolic to handle singular Bool updates (#53690) (#54863)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54863

Adds support for cases where the update to the index_put node is a single scalar value, such as the Bool case shown below:

```
mask[indices] = True
```

Fixes #53507
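
For context, a minimal export repro sketch (the module name, file name, and opset choice are illustrative, not taken from this PR):

```
import torch

class MaskFill(torch.nn.Module):
    def forward(self, mask, indices):
        mask[indices] = True
        return mask

mask = torch.zeros(100, dtype=torch.bool)
indices = (torch.rand(25) * mask.shape[0]).to(torch.int64)
# index_put with dynamic indices requires opset >= 11
torch.onnx.export(MaskFill(), (mask, indices), "mask_fill.onnx", opset_version=11)
```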

Test Plan: Imported from OSS

Reviewed By: nikithamalgifb

Differential Revision: D27408977

Pulled By: SplitInfinity

fbshipit-source-id: bcfb55b50ce76b3d4913ffbc16cdef1f98cb7a84
diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py
index 6cb313c..e832ff7 100644
--- a/test/onnx/test_pytorch_onnx_onnxruntime.py
+++ b/test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -1755,6 +1755,26 @@
         self.run_test(IndexPutModel(), (x, ind, update))
 
     @skipIfUnsupportedMinOpsetVersion(11)
+    def test_index_put_singular(self):
+        class IndexPutBoolModel(torch.nn.Module):
+            def forward(self, mask, indices):
+                mask[indices] = True
+                return mask
+
+        mask = torch.zeros(100, dtype=torch.bool)
+        indices = (torch.rand(25) * mask.shape[0]).to(torch.int64)
+        self.run_test(IndexPutBoolModel(), (mask, indices))
+
+        class IndexPutFloatModel(torch.nn.Module):
+            def forward(self, mask, indices):
+                mask[indices] = torch.tensor(5.5)
+                return mask
+
+        mask = torch.rand(100, dtype=torch.float)
+        indices = (torch.rand(50) * mask.shape[0]).to(torch.int64)
+        self.run_test(IndexPutFloatModel(), (mask, indices))
+
+    @skipIfUnsupportedMinOpsetVersion(11)
     def test_index_put_accumulate(self):
         class IndexPutModel(torch.nn.Module):
             def forward(self, x, ind, update):
diff --git a/torch/onnx/symbolic_opset11.py b/torch/onnx/symbolic_opset11.py
index 2e0c1bd..a129e90 100644
--- a/torch/onnx/symbolic_opset11.py
+++ b/torch/onnx/symbolic_opset11.py
@@ -136,6 +136,10 @@
     sub_data_shape = sym_help._slice_helper(
         g, g.op("Shape", self), axes=[0], starts=[len(indices_list)], ends=[maxsize])
     values_shape = g.op("Concat", broadcast_index_shape, sub_data_shape, axis_i=0)
+    # Check if values is a singular value and expand accordingly
+    rank = sym_help._get_tensor_rank(values)
+    if rank is not None and rank == 0:
+        values = expand(g, values, values_shape, None)
     values = g.op("Reshape", values, values_shape)
 
     if accumulate:
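
Note on the fix above: when the values tensor has rank 0, it is first expanded to the computed values shape so the subsequent Reshape sees one update per index. A conceptual sketch in plain eager-mode PyTorch (not the exporter code; variable names are illustrative):

```
import torch

mask = torch.zeros(100, dtype=torch.bool)
indices = (torch.rand(25) * mask.shape[0]).to(torch.int64)
value = torch.tensor(True)              # rank-0 update
expanded = value.expand(indices.shape)  # one update per index location
mask.index_put_((indices,), expanded)   # equivalent to mask[indices] = True
```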