[ONNX] Fix gather squeeze axis in constant folding (#63588) (#64379)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64379

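* Fix gather squeeze axis in constant folding: when the indices tensor has
  rank 0, squeeze only the gathered axis (`squeeze(axis)`) instead of every
  size-1 dimension (`squeeze()`), so the folded output has rank
  rank_of_input - 1.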

* mypy

* fix indent

* address comments

Test Plan: Imported from OSS

Reviewed By: jansel

Differential Revision: D30919604

Pulled By: malfet

fbshipit-source-id: 90edb054491433a0da2fe82324ac7c12f1ef062b
diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py
index 3b7ff9f..653bea3 100644
--- a/test/onnx/test_pytorch_onnx_onnxruntime.py
+++ b/test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -3571,7 +3571,6 @@
         self.run_test(GatherModel(), input=(input, indices))
 
     @disableScriptTest()  # RuntimeError: Python type cannot be used as a value
-    @skipIfUnsupportedMinOpsetVersion(11)
     def test_gather_constant_fold(self):
         class GatherModule(torch.nn.Module):
             def __init__(self):
@@ -3602,6 +3601,21 @@
         x = torch.randn(1, 3, 2)
         self.run_test(GatherModule(), (x,))
 
+        class GatherModule(torch.nn.Module):
+            def __init__(self):
+                super(GatherModule, self).__init__()
+                self.register_buffer("rb", torch.randn(1, 1, 3, 1, 1))
+
+            def forward(self, x):
+                x += self.rb[0]
+                return x
+
+        x = torch.randn(1, 3, 224, 224)
+        self.run_test(GatherModule(), (x,),
+                      dynamic_axes={"input": {0: "batch", 2: "height", 3: "width"},
+                                    "output": {0: "batch", 1: "class", 2: "height", 3: "width"}},
+                      input_names=['input'], output_names=['output'])
+
     @skipIfUnsupportedOpsetVersion([13])
     @skipIfUnsupportedMinOpsetVersion(9)
     def test_expand(self):
diff --git a/torch/csrc/jit/passes/onnx/constant_fold.cpp b/torch/csrc/jit/passes/onnx/constant_fold.cpp
index cce5a43..8d20c78 100644
--- a/torch/csrc/jit/passes/onnx/constant_fold.cpp
+++ b/torch/csrc/jit/passes/onnx/constant_fold.cpp
@@ -476,7 +476,7 @@
     // If rank of indices is 0, rank of output tensor should be
     // rank_of_input - 1.
     if (q < 1) {
-      updated_val = updated_val.squeeze();
+      updated_val = updated_val.squeeze(axis);
     }
     return c10::optional<at::Tensor>(updated_val);
   } else if (node->kind() == onnx::Range) {