Sort export w/ negative axes (#31971)

Summary:
Fixing export of Sort on negative axes
Pull Request resolved: https://github.com/pytorch/pytorch/pull/31971

Reviewed By: hl475

Differential Revision: D19325874

Pulled By: houseroad

fbshipit-source-id: 18ab2bf39221970c8ab65a1355f5759f88faa54f
diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py
index 6bb10ff..51655ac 100644
--- a/test/onnx/test_pytorch_onnx_onnxruntime.py
+++ b/test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -2038,30 +2038,26 @@
 
     def test_sort(self):
         class SortModel(torch.nn.Module):
-            def __init__(self, dim):
-                super(SortModel, self).__init__()
-                self.dim = dim
-
             def forward(self, x):
-                return torch.sort(x, dim=self.dim, descending=True)
+                out = []
+                for i in range(-2, 2):
+                    out.append(torch.sort(x, dim=i, descending=True))
+                return out
 
-        dim = 1
         x = torch.randn(3, 4)
-        self.run_test(SortModel(dim), x)
+        self.run_test(SortModel(), x)
 
     @skipIfUnsupportedMinOpsetVersion(11)
     def test_sort_ascending(self):
         class SortModel(torch.nn.Module):
-            def __init__(self, dim):
-                super(SortModel, self).__init__()
-                self.dim = dim
-
             def forward(self, x):
-                return torch.sort(x, dim=self.dim, descending=False)
+                out = []
+                for i in range(-2, 2):
+                    out.append(torch.sort(x, dim=i, descending=False))
+                return out
 
-        dim = 1
         x = torch.randn(3, 4)
-        self.run_test(SortModel(dim), x)
+        self.run_test(SortModel(), x)
 
     @skipIfUnsupportedMinOpsetVersion(9)
     def test_masked_fill(self):
diff --git a/torch/onnx/symbolic_helper.py b/torch/onnx/symbolic_helper.py
index 054ad42..61b2333 100644
--- a/torch/onnx/symbolic_helper.py
+++ b/torch/onnx/symbolic_helper.py
@@ -213,16 +213,13 @@
     if out is not None:
         _unimplemented("Sort", "Out parameter is not supported")
     shape_ = g.op("Shape", input)
-    axis = g.op("Constant", value_t=torch.tensor(0, dtype=torch.int64))
-    start = g.op("Constant", value_t=torch.tensor(dim, dtype=torch.int64))
-    end = g.op("Constant", value_t=torch.tensor(dim + 1, dtype=torch.int64))
-    slice_ = _slice_helper(g, shape_, axes=axis, starts=start, ends=end, steps=None, dynamic_slice=True)
+    dim_size_ = g.op("Gather", shape_, g.op("Constant", value_t=torch.tensor([dim], dtype=torch.int64)))
     if _export_onnx_opset_version <= 10:
         if not decending:
             _unimplemented("Sort", "Ascending is not supported")
-        return g.op("TopK", input, slice_, axis_i=dim, outputs=2)
+        return g.op("TopK", input, dim_size_, axis_i=dim, outputs=2)
     else:
-        return g.op("TopK", input, slice_, axis_i=dim, largest_i=decending, outputs=2)
+        return g.op("TopK", input, dim_size_, axis_i=dim, largest_i=decending, outputs=2)
 
 
 def _topk_helper(g, input, k, dim, largest=True, sorted=False, out=None):