[ONNX] Fix gather squeeze axis in constant folding (#63588) (#64379)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64379
* Fix gather squeeze axis in constant folding
* mypy
* fix indent
* address comments
Test Plan: Imported from OSS
Reviewed By: jansel
Differential Revision: D30919604
Pulled By: malfet
fbshipit-source-id: 90edb054491433a0da2fe82324ac7c12f1ef062b
diff --git a/test/onnx/test_pytorch_onnx_onnxruntime.py b/test/onnx/test_pytorch_onnx_onnxruntime.py
index 3b7ff9f..653bea3 100644
--- a/test/onnx/test_pytorch_onnx_onnxruntime.py
+++ b/test/onnx/test_pytorch_onnx_onnxruntime.py
@@ -3571,7 +3571,6 @@
self.run_test(GatherModel(), input=(input, indices))
@disableScriptTest() # RuntimeError: Python type cannot be used as a value
- @skipIfUnsupportedMinOpsetVersion(11)
def test_gather_constant_fold(self):
class GatherModule(torch.nn.Module):
def __init__(self):
@@ -3602,6 +3601,21 @@
x = torch.randn(1, 3, 2)
self.run_test(GatherModule(), (x,))
+ class GatherModule(torch.nn.Module):
+ def __init__(self):
+ super(GatherModule, self).__init__()
+ self.register_buffer("rb", torch.randn(1, 1, 3, 1, 1))
+
+ def forward(self, x):
+ x += self.rb[0]
+ return x
+
+ x = torch.randn(1, 3, 224, 224)
+ self.run_test(GatherModule(), (x,),
+ dynamic_axes={"input": {0: "batch", 2: "height", 3: "width"},
+ "output": {0: "batch", 1: "class", 2: "height", 3: "width"}},
+ input_names=['input'], output_names=['output'])
+
@skipIfUnsupportedOpsetVersion([13])
@skipIfUnsupportedMinOpsetVersion(9)
def test_expand(self):
diff --git a/torch/csrc/jit/passes/onnx/constant_fold.cpp b/torch/csrc/jit/passes/onnx/constant_fold.cpp
index cce5a43..8d20c78 100644
--- a/torch/csrc/jit/passes/onnx/constant_fold.cpp
+++ b/torch/csrc/jit/passes/onnx/constant_fold.cpp
@@ -476,7 +476,7 @@
// If rank of indices is 0, rank of output tensor should be
// rank_of_input - 1.
if (q < 1) {
- updated_val = updated_val.squeeze();
+ updated_val = updated_val.squeeze(axis);
}
return c10::optional<at::Tensor>(updated_val);
} else if (node->kind() == onnx::Range) {