[ONNX] Improve index_put symbolic to handle singular Bool updates (#53690) (#54863)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/54863
Adds support for cases where the updates to the index_put node is a single Bool value, such as the case shown below
```
mask[indices] = True
```
Fixes #53507
Test Plan: Imported from OSS
Reviewed By: nikithamalgifb
Differential Revision: D27408977
Pulled By: SplitInfinity
fbshipit-source-id: bcfb55b50ce76b3d4913ffbc16cdef1f98cb7a84
diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py
index 6cb313c..e832ff7 100644
--- a/test/onnx/test_pytorch_onnx_onnxruntime.py
+++ b/test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -1755,6 +1755,26 @@
self.run_test(IndexPutModel(), (x, ind, update))
@skipIfUnsupportedMinOpsetVersion(11)
+ def test_index_put_singular(self):
+ class IndexPutBoolModel(torch.nn.Module):
+ def forward(self, mask, indices):
+ mask[indices] = True
+ return mask
+
+ mask = torch.zeros(100, dtype=torch.bool)
+ indices = (torch.rand(25) * mask.shape[0]).to(torch.int64)
+ self.run_test(IndexPutBoolModel(), (mask, indices))
+
+ class IndexPutFloatModel(torch.nn.Module):
+ def forward(self, mask, indices):
+ mask[indices] = torch.tensor(5.5)
+ return mask
+
+ mask = torch.rand(100, dtype=torch.float)
+ indices = (torch.rand(50) * mask.shape[0]).to(torch.int64)
+ self.run_test(IndexPutFloatModel(), (mask, indices))
+
+ @skipIfUnsupportedMinOpsetVersion(11)
def test_index_put_accumulate(self):
class IndexPutModel(torch.nn.Module):
def forward(self, x, ind, update):
diff --git a/torch/onnx/symbolic_opset11.py b/torch/onnx/symbolic_opset11.py
index 2e0c1bd..a129e90 100644
--- a/torch/onnx/symbolic_opset11.py
+++ b/torch/onnx/symbolic_opset11.py
@@ -136,6 +136,10 @@
sub_data_shape = sym_help._slice_helper(
g, g.op("Shape", self), axes=[0], starts=[len(indices_list)], ends=[maxsize])
values_shape = g.op("Concat", broadcast_index_shape, sub_data_shape, axis_i=0)
+ # Check if values is a singular value and expand accordingly
+ rank = sym_help._get_tensor_rank(values)
+ if rank is not None and rank == 0:
+ values = expand(g, values, values_shape, None)
values = g.op("Reshape", values, values_shape)
if accumulate: