Sort export w/ negative axes (#31971)
Summary:
Fixing export of Sort on negative axes
Pull Request resolved: https://github.com/pytorch/pytorch/pull/31971
Reviewed By: hl475
Differential Revision: D19325874
Pulled By: houseroad
fbshipit-source-id: 18ab2bf39221970c8ab65a1355f5759f88faa54f
diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py
index 6bb10ff..51655ac 100644
--- a/test/onnx/test_pytorch_onnx_onnxruntime.py
+++ b/test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -2038,30 +2038,26 @@
def test_sort(self):
class SortModel(torch.nn.Module):
- def __init__(self, dim):
- super(SortModel, self).__init__()
- self.dim = dim
-
def forward(self, x):
- return torch.sort(x, dim=self.dim, descending=True)
+ out = []
+ for i in range(-2, 2):
+ out.append(torch.sort(x, dim=i, descending=True))
+ return out
- dim = 1
x = torch.randn(3, 4)
- self.run_test(SortModel(dim), x)
+ self.run_test(SortModel(), x)
@skipIfUnsupportedMinOpsetVersion(11)
def test_sort_ascending(self):
class SortModel(torch.nn.Module):
- def __init__(self, dim):
- super(SortModel, self).__init__()
- self.dim = dim
-
def forward(self, x):
- return torch.sort(x, dim=self.dim, descending=False)
+ out = []
+ for i in range(-2, 2):
+ out.append(torch.sort(x, dim=i, descending=False))
+ return out
- dim = 1
x = torch.randn(3, 4)
- self.run_test(SortModel(dim), x)
+ self.run_test(SortModel(), x)
@skipIfUnsupportedMinOpsetVersion(9)
def test_masked_fill(self):
diff --git a/torch/onnx/symbolic_helper.py b/torch/onnx/symbolic_helper.py
index 054ad42..61b2333 100644
--- a/torch/onnx/symbolic_helper.py
+++ b/torch/onnx/symbolic_helper.py
@@ -213,16 +213,13 @@
if out is not None:
_unimplemented("Sort", "Out parameter is not supported")
shape_ = g.op("Shape", input)
- axis = g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64))
- start = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.int64))
- end = g.op("Constant", value_t=torch.tensor(dim + 1, dtype=torch.int64))
- slice_ = _slice_helper(g, shape_, axes=axis, starts=start, ends=end, steps=None, dynamic_slice=True)
+ dim_size_ = g.op("Gather", shape_, g.op("Constant", value_t=torch.tensor([dim], dtype=torch.int64)))
if _export_onnx_opset_version <= 10:
if not decending:
_unimplemented("Sort", "Ascending is not supported")
- return g.op("TopK", input, slice_, axis_i=dim, outputs=2)
+ return g.op("TopK", input, dim_size_, axis_i=dim, outputs=2)
else:
- return g.op("TopK", input, slice_, axis_i=dim, largest_i=decending, outputs=2)
+ return g.op("TopK", input, dim_size_, axis_i=dim, largest_i=decending, outputs=2)
def _topk_helper(g, input, k, dim, largest=True, sorted=False, out=None):